diff options
17 files changed, 210 insertions, 52 deletions
diff --git a/tensorflow/compiler/jit/kernels/xla_local_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_local_launch_op.cc index 032ded54e6..a58abcbdff 100644 --- a/tensorflow/compiler/jit/kernels/xla_local_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_local_launch_op.cc @@ -159,13 +159,17 @@ XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx) } } -Status XlaLocalLaunchOp::BuildCompilationCache(XlaCompilationCache** cache) { +Status XlaLocalLaunchOp::BuildCompilationCache(OpKernelContext* ctx, + XlaCompilationCache** cache) { auto platform = gpu::MultiPlatformManager::PlatformWithId(platform_id_); if (!platform.ok()) { return StreamExecutorUtil::ConvertStatus(platform.status()); } - auto client = - xla::ClientLibrary::GetOrCreateLocalClient(platform.ValueOrDie()); + xla::LocalClientOptions client_options; + client_options.set_platform(platform.ValueOrDie()); + client_options.set_intra_op_parallelism_threads( + ctx->device()->tensorflow_cpu_worker_threads()->num_threads); + auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options); if (!client.ok()) { return client.status(); } @@ -194,8 +198,8 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { XlaCompilationCache* cache; OP_REQUIRES_OK(ctx, rm->LookupOrCreate<XlaCompilationCache>( rm->default_container(), "xla_cache", &cache, - [this](XlaCompilationCache** cache) { - return BuildCompilationCache(cache); + [this, ctx](XlaCompilationCache** cache) { + return BuildCompilationCache(ctx, cache); })); // Hold the reference to the JIT during evaluation. (We could probably // free it sooner because the ResourceMgr will retain a reference, but @@ -264,8 +268,6 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { xla::ExecutableRunOptions run_options; run_options.set_stream(stream); run_options.set_allocator(&xla_allocator); - run_options.set_inter_op_thread_pool( - ctx->device()->tensorflow_cpu_worker_threads()->workers); run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device()); Env* env = Env::Default(); auto start_time = env->NowMicros(); diff --git a/tensorflow/compiler/jit/kernels/xla_local_launch_op.h b/tensorflow/compiler/jit/kernels/xla_local_launch_op.h index 51887cf013..5e4d3336a9 100644 --- a/tensorflow/compiler/jit/kernels/xla_local_launch_op.h +++ b/tensorflow/compiler/jit/kernels/xla_local_launch_op.h @@ -44,7 +44,8 @@ class XlaLocalLaunchOp : public OpKernel { private: // Builds a XlaCompilationCache class suitable for the current device. - Status BuildCompilationCache(XlaCompilationCache** compiler); + Status BuildCompilationCache(OpKernelContext* ctx, + XlaCompilationCache** compiler); DeviceType device_type_; NameAttrList function_; diff --git a/tensorflow/compiler/xla/client/client_library.cc b/tensorflow/compiler/xla/client/client_library.cc index eb9a7ff2ac..8238261e1c 100644 --- a/tensorflow/compiler/xla/client/client_library.cc +++ b/tensorflow/compiler/xla/client/client_library.cc @@ -43,6 +43,16 @@ int LocalClientOptions::number_of_replicas() const { return number_of_replicas_; } +LocalClientOptions& LocalClientOptions::set_intra_op_parallelism_threads( + int num_threads) { + intra_op_parallelism_threads_ = num_threads; + return *this; +} + +int LocalClientOptions::intra_op_parallelism_threads() const { + return intra_op_parallelism_threads_; +} + /* static */ ClientLibrary& ClientLibrary::Singleton() { static ClientLibrary* c = new ClientLibrary; return *c; @@ -77,6 +87,8 @@ ClientLibrary::~ClientLibrary() = default; ServiceOptions service_options; service_options.set_platform(platform); service_options.set_number_of_replicas(replica_count); + service_options.set_intra_op_parallelism_threads( + options.intra_op_parallelism_threads()); auto instance = MakeUnique<LocalInstance>(); TF_ASSIGN_OR_RETURN(instance->service, diff --git a/tensorflow/compiler/xla/client/client_library.h b/tensorflow/compiler/xla/client/client_library.h index 49f4541437..3ddd235d0e 100644 --- a/tensorflow/compiler/xla/client/client_library.h +++ b/tensorflow/compiler/xla/client/client_library.h @@ -53,9 +53,14 @@ class LocalClientOptions { LocalClientOptions& set_number_of_replicas(int number_of_replicas); int number_of_replicas() const; + // Sets the thread pool size for parallel execution of an individual operator. + LocalClientOptions& set_intra_op_parallelism_threads(int num_threads); + int intra_op_parallelism_threads() const; + private: perftools::gputools::Platform* platform_ = nullptr; int number_of_replicas_ = -1; + int intra_op_parallelism_threads_ = -1; }; class ClientLibrary { diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index 0de58ea7dc..02cf57e763 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -185,8 +185,15 @@ StatusOr<std::unique_ptr<ShapedBuffer>> LocalExecutable::Run( if (options.allocator() == nullptr) { actual_options.set_allocator(backend_->memory_allocator()); } - ServiceExecutableRunOptions service_options(actual_options, - backend_->StreamBorrower()); + + // For local client execution on CPU backends: + // *) The thread pool used for eigen CPU ops is from + // ExecutableRunOptions.eigen_intra_op_thread_pool. + // *) The thread pool used for XLA CPU ops is from + // backend_->eigen_intra_op_thread_pool(). + ServiceExecutableRunOptions service_options( + actual_options, backend_->StreamBorrower(), + backend_->eigen_intra_op_thread_pool()); if (executable_->dumping()) { return ExecuteAndDump(&service_options, arguments); diff --git a/tensorflow/compiler/xla/service/backend.cc b/tensorflow/compiler/xla/service/backend.cc index 5c05417c6d..1913617fec 100644 --- a/tensorflow/compiler/xla/service/backend.cc +++ b/tensorflow/compiler/xla/service/backend.cc @@ -41,13 +41,39 @@ namespace se = ::perftools::gputools; namespace xla { +BackendOptions& BackendOptions::set_platform( + perftools::gputools::Platform* platform) { + platform_ = platform; + return *this; +} + +perftools::gputools::Platform* BackendOptions::platform() const { + return platform_; +} + +BackendOptions& BackendOptions::set_number_of_replicas(int number_of_replicas) { + number_of_replicas_ = number_of_replicas; + return *this; +} + +int BackendOptions::number_of_replicas() const { return number_of_replicas_; } + +BackendOptions& BackendOptions::set_intra_op_parallelism_threads( + int num_threads) { + intra_op_parallelism_threads_ = num_threads; + return *this; +} + +int BackendOptions::intra_op_parallelism_threads() const { + return intra_op_parallelism_threads_; +} + // Define this in .cc file to avoid having to include eigen or forward declare // these types in the header. struct Backend::EigenThreadPoolWrapper { - explicit EigenThreadPoolWrapper() - : pool(new tensorflow::thread::ThreadPool( - tensorflow::Env::Default(), "XLAEigen", - tensorflow::port::NumSchedulableCPUs())), + explicit EigenThreadPoolWrapper(const int num_threads) + : pool(new tensorflow::thread::ThreadPool(tensorflow::Env::Default(), + "XLAEigen", num_threads)), wrapper(new tensorflow::EigenThreadPoolWrapper(pool.get())), device(new Eigen::ThreadPoolDevice(wrapper.get(), wrapper->NumThreads())) {} @@ -58,18 +84,21 @@ struct Backend::EigenThreadPoolWrapper { }; /* static */ StatusOr<std::unique_ptr<Backend>> Backend::CreateBackend( - perftools::gputools::Platform* platform, int64 replica_count) { + const BackendOptions& options) { + int64 replica_count = options.number_of_replicas(); if (replica_count == -1) { legacy_flags::BackendFlags* flags = legacy_flags::GetBackendFlags(); replica_count = flags->xla_replicas; } + perftools::gputools::Platform* platform = options.platform(); TF_ASSIGN_OR_RETURN(auto compiler, Compiler::GetForPlatform(platform)); TF_ASSIGN_OR_RETURN(auto stream_executors, PlatformUtil::GetStreamExecutors(platform)); TF_ASSIGN_OR_RETURN(auto transfer_manager, TransferManager::GetForPlatform(platform)); - std::unique_ptr<Backend> backend(new Backend( - replica_count, platform, compiler, stream_executors, transfer_manager)); + std::unique_ptr<Backend> backend( + new Backend(replica_count, platform, compiler, stream_executors, + transfer_manager, options.intra_op_parallelism_threads())); TF_RETURN_IF_ERROR(backend->PoolStreams(kInitialStreamsToPool, backend->default_stream_executor())); return std::move(backend); @@ -79,7 +108,9 @@ struct Backend::EigenThreadPoolWrapper { Backend::CreateDefaultBackend() { TF_ASSIGN_OR_RETURN(se::Platform * platform, PlatformUtil::GetDefaultPlatform()); - return CreateBackend(platform); + BackendOptions backend_options; + backend_options.set_platform(platform); + return CreateBackend(backend_options); } tensorflow::Status Backend::PoolStreams(int n, se::StreamExecutor* executor) { @@ -114,7 +145,7 @@ Backend::Backend( int64 replica_count, perftools::gputools::Platform* platform, Compiler* compiler, tensorflow::gtl::ArraySlice<se::StreamExecutor*> stream_executors, - TransferManager* transfer_manager) + TransferManager* transfer_manager, int intra_op_parallelism_threads) : platform_(platform), compiler_(compiler), transfer_manager_(transfer_manager), @@ -144,7 +175,11 @@ Backend::Backend( inter_op_thread_pool_.reset(new tensorflow::thread::ThreadPool( tensorflow::Env::Default(), "xla_inter_op", tensorflow::port::NumSchedulableCPUs())); - intra_op_thread_pool_wrapper_.reset(new EigenThreadPoolWrapper()); + const int num_threads = intra_op_parallelism_threads > 0 + ? intra_op_parallelism_threads + : tensorflow::port::NumSchedulableCPUs(); + intra_op_thread_pool_wrapper_.reset( + new EigenThreadPoolWrapper(num_threads)); } } @@ -190,10 +225,17 @@ tensorflow::thread::ThreadPool* Backend::inter_op_thread_pool() const { const Eigen::ThreadPoolDevice* Backend::eigen_intra_op_thread_pool_device() const { - if (intra_op_thread_pool_wrapper_ == nullptr) return nullptr; + if (intra_op_thread_pool_wrapper_ == nullptr) { + return nullptr; + } return intra_op_thread_pool_wrapper_->device.get(); } +tensorflow::thread::ThreadPool* Backend::eigen_intra_op_thread_pool() const { + if (intra_op_thread_pool_wrapper_ == nullptr) return nullptr; + return intra_op_thread_pool_wrapper_->pool.get(); +} + StatusOr<perftools::gputools::StreamExecutor*> Backend::stream_executor( int device_ordinal) const { if (device_ordinal < 0 || diff --git a/tensorflow/compiler/xla/service/backend.h b/tensorflow/compiler/xla/service/backend.h index 9f6829b7d9..1068bac277 100644 --- a/tensorflow/compiler/xla/service/backend.h +++ b/tensorflow/compiler/xla/service/backend.h @@ -39,6 +39,31 @@ struct ThreadPoolDevice; namespace xla { +// Options to configure the backend when it is created. +class BackendOptions { + public: + // Set the platform backing the backend, or nullptr for the default platform. + BackendOptions& set_platform(perftools::gputools::Platform* platform); + perftools::gputools::Platform* platform() const; + + // Set the number of replicas to use when compiling replicated + // programs. The default is -1 meaning that the value is read from + // the xla_replicas flag. + BackendOptions& set_number_of_replicas(int number_of_replicas); + int number_of_replicas() const; + + // Sets the thread pool size for parallel execution of an individual operator. + // The default value of -1 will result in initializing the thread pool with + // the number of threads equal to the number of cores in the system. + BackendOptions& set_intra_op_parallelism_threads(int num_threads); + int intra_op_parallelism_threads() const; + + private: + perftools::gputools::Platform* platform_ = nullptr; + int number_of_replicas_ = -1; + int intra_op_parallelism_threads_ = -1; +}; + // Class which encapsulates an XLA backend. It includes everything necessary // to compile and execute computations on a particular platform. // @@ -53,9 +78,9 @@ class Backend { static constexpr int kInitialStreamsToPool = 8; // Creates a new backend for the given platform with the given number of - // replicas. A value of -1 means to use the flag value. + // replicas. static StatusOr<std::unique_ptr<Backend>> CreateBackend( - perftools::gputools::Platform* platform, int64 replica_count = -1); + const BackendOptions& options); // Creates a backend for the default platform. The default platform is defined // in PlatformUtil. @@ -150,6 +175,7 @@ class Backend { // For the host platform, returns the configured eigen threadpool device to be // used for scheduling work. For other platforms, returns NULL. const Eigen::ThreadPoolDevice* eigen_intra_op_thread_pool_device() const; + tensorflow::thread::ThreadPool* eigen_intra_op_thread_pool() const; // Resets the devices associated with this backend. Status ResetDevices(); @@ -160,7 +186,7 @@ class Backend { Compiler* compiler, tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*> stream_executors, - TransferManager* transfer_manager); + TransferManager* transfer_manager, int intra_op_parallelism_threads); Backend(const Backend&) = delete; Backend& operator=(const Backend&) = delete; diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc index 7a4723e8d7..cadad10910 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc @@ -146,7 +146,7 @@ Status ParallelCpuExecutable::AllocateBuffers( } Status ParallelCpuExecutable::ExecuteComputeFunctions( - const ExecutableRunOptions* run_options, + const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments, tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> buffers, HloExecutionProfile* hlo_execution_profile) { @@ -160,7 +160,7 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions( } Status ParallelCpuExecutable::ExecuteComputeFunctions( - const ExecutableRunOptions* run_options, + const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments, tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> buffers, HloExecutionProfile* hlo_execution_profile) { @@ -214,7 +214,7 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions( void** temps_array = buffer_pointers.data(); uint64* profile_counters_array = profile_counters.data(); - auto* thread_pool = CHECK_NOTNULL(run_options->inter_op_thread_pool()); + auto* thread_pool = CHECK_NOTNULL(run_options->xla_intra_op_thread_pool()); tensorflow::mutex completion_queue_lock; tensorflow::condition_variable completion_queue_cv; std::deque<HloInstruction*> completion_queue; @@ -251,11 +251,12 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions( }); auto function = FindOrDie(functions, instruction); // The thread pool entry takes ownership of |operand_buffers|. + const auto* exec_run_options = &run_options->run_options(); thread_pool->Schedule([instruction, &completion_queue, &completion_queue_lock, &completion_queue_cv, - result_buffer, run_options, operand_buffers, + result_buffer, exec_run_options, operand_buffers, temps_array, profile_counters_array, function] { - function(result_buffer, run_options, operand_buffers, temps_array, + function(result_buffer, exec_run_options, operand_buffers, temps_array, profile_counters_array); delete[] operand_buffers; // Push the completed HLO instruction on the queue, the main thread @@ -345,9 +346,8 @@ ParallelCpuExecutable::ExecuteOnStream( const BufferAllocation::Index result_index = result_slice.index(); VLOG(3) << "result index: " << result_index; - TF_RETURN_IF_ERROR(ExecuteComputeFunctions(&run_options->run_options(), - arguments, device_allocations, - hlo_execution_profile)); + TF_RETURN_IF_ERROR(ExecuteComputeFunctions( + run_options, arguments, device_allocations, hlo_execution_profile)); // Mark the buffers that are actually live (used in the output) when the // computation finishes executing. @@ -400,8 +400,8 @@ StatusOr<std::unique_ptr<ShapedBuffer>> ParallelCpuExecutable::ExecuteOnStream( TF_RETURN_IF_ERROR(AllocateBuffers( memory_allocator, stream->parent()->device_ordinal(), &buffers)); - TF_RETURN_IF_ERROR(ExecuteComputeFunctions( - &run_options->run_options(), arguments, buffers, hlo_execution_profile)); + TF_RETURN_IF_ERROR(ExecuteComputeFunctions(run_options, arguments, buffers, + hlo_execution_profile)); // Copy DeviceMemoryBase values which contain the array(s) of the result into // the respective location in ShapedBuffer which is returned to the caller. diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h index 7223de9f07..6e1239d590 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h @@ -96,14 +96,14 @@ class ParallelCpuExecutable : public Executable { // Calls the generated functions in 'function_names_', performing the // computation with the given arguments using the supplied buffers. Status ExecuteComputeFunctions( - const ExecutableRunOptions* run_options, + const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase> arguments, tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase> buffers, HloExecutionProfile* hlo_execution_profile); Status ExecuteComputeFunctions( - const ExecutableRunOptions* run_options, + const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments, tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase> buffers, diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 8184dbabc8..78d21233c7 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -60,9 +60,12 @@ namespace xla { TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform()); } - TF_ASSIGN_OR_RETURN( - std::unique_ptr<Backend> backend, - Backend::CreateBackend(platform, options.number_of_replicas())); + BackendOptions backend_options; + backend_options.set_platform(platform) + .set_number_of_replicas(options.number_of_replicas()) + .set_intra_op_parallelism_threads(options.intra_op_parallelism_threads()); + TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> backend, + Backend::CreateBackend(backend_options)); TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> compute_constant_backend, CreateComputeConstantBackend()); diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index c001e705de..42450dfcae 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -112,6 +112,16 @@ ServiceOptions& ServiceOptions::set_number_of_replicas(int number_of_replicas) { int ServiceOptions::number_of_replicas() const { return number_of_replicas_; } +ServiceOptions& ServiceOptions::set_intra_op_parallelism_threads( + int num_threads) { + intra_op_parallelism_threads_ = num_threads; + return *this; +} + +int ServiceOptions::intra_op_parallelism_threads() const { + return intra_op_parallelism_threads_; +} + /* static */ StatusOr<std::unique_ptr<Service>> Service::NewService( perftools::gputools::Platform* platform) { ServiceOptions default_options; @@ -126,9 +136,10 @@ int ServiceOptions::number_of_replicas() const { return number_of_replicas_; } if (platform == nullptr) { TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform()); } - TF_ASSIGN_OR_RETURN( - execute_backend, - Backend::CreateBackend(platform, options.number_of_replicas())); + BackendOptions backend_options; + backend_options.set_platform(platform); + backend_options.set_number_of_replicas(options.number_of_replicas()); + TF_ASSIGN_OR_RETURN(execute_backend, Backend::CreateBackend(backend_options)); TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> compute_constant_backend, CreateComputeConstantBackend()); std::unique_ptr<Service> service(new Service( @@ -142,7 +153,10 @@ Service::CreateComputeConstantBackend() { PlatformUtil::GetSupportedPlatforms()); for (auto* platform : platforms) { if (platform->id() == se::host::kHostPlatformId) { - return Backend::CreateBackend(platform, /*replica_count=*/1); + BackendOptions backend_options; + backend_options.set_platform(platform); + backend_options.set_number_of_replicas(1); + return Backend::CreateBackend(backend_options); } } return NotFound("CPU platform not found"); @@ -573,7 +587,8 @@ StatusOr<GlobalDataHandle> Service::ExecuteAndRegisterResult( options.set_inter_op_thread_pool(backend->inter_op_thread_pool()); options.set_intra_op_thread_pool( backend->eigen_intra_op_thread_pool_device()); - run_options.emplace_back(options, backend->StreamBorrower()); + run_options.emplace_back(options, backend->StreamBorrower(), + backend->inter_op_thread_pool()); } perftools::gputools::DeviceMemoryBase result; diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index 0e0e7c4e21..05a955137f 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -63,9 +63,14 @@ class ServiceOptions { ServiceOptions& set_number_of_replicas(int number_of_replicas); int number_of_replicas() const; + // Sets the thread pool size for parallel execution of an individual operator. + ServiceOptions& set_intra_op_parallelism_threads(int num_threads); + int intra_op_parallelism_threads() const; + private: perftools::gputools::Platform* platform_ = nullptr; int number_of_replicas_ = -1; + int intra_op_parallelism_threads_ = -1; }; // The XLA service object, which is the same across all diff --git a/tensorflow/compiler/xla/service/service_executable_run_options.h b/tensorflow/compiler/xla/service/service_executable_run_options.h index 0d4b214f5f..017e5ef09e 100644 --- a/tensorflow/compiler/xla/service/service_executable_run_options.h +++ b/tensorflow/compiler/xla/service/service_executable_run_options.h @@ -30,10 +30,12 @@ class ServiceExecutableRunOptions { using StreamBorrower = std::function<StatusOr<Pool<perftools::gputools::Stream>::SmartPtr>(int)>; - explicit ServiceExecutableRunOptions(ExecutableRunOptions run_options, - StreamBorrower borrow_stream = nullptr) + explicit ServiceExecutableRunOptions( + ExecutableRunOptions run_options, StreamBorrower borrow_stream = nullptr, + tensorflow::thread::ThreadPool* xla_intra_op_thread_pool = nullptr) : run_options_(std::move(run_options)), - borrow_stream_(std::move(borrow_stream)) {} + borrow_stream_(std::move(borrow_stream)), + xla_intra_op_thread_pool_(xla_intra_op_thread_pool) {} // Returns reference or pointer to `ExecutableRunOptions` member. const ExecutableRunOptions& run_options() const { return run_options_; } @@ -53,9 +55,15 @@ class ServiceExecutableRunOptions { : Status(tensorflow::error::UNIMPLEMENTED, "No stream cache"); } + // Returns reference to thread pool for execution of XLA ops on CPU backend. + tensorflow::thread::ThreadPool* xla_intra_op_thread_pool() const { + return xla_intra_op_thread_pool_; + } + private: ExecutableRunOptions run_options_; StreamBorrower borrow_stream_; + tensorflow::thread::ThreadPool* xla_intra_op_thread_pool_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index d7644a0513..e0c2b9ab09 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -200,11 +200,13 @@ cc_library( "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:local_service", "//tensorflow/compiler/xla/service:platform_util", + "//tensorflow/compiler/xla/service:pool", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "//third_party/eigen3", ], ) @@ -1362,6 +1364,7 @@ cc_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/service:computation_tracker", "//tensorflow/compiler/xla/service:local_service", + "//tensorflow/core:lib", "//tensorflow/core:test_main", ], ) diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 4e956bc00c..f741ff38b5 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -111,8 +111,9 @@ StatusOr<se::DeviceMemoryBase> HloTestBase::Execute( backend_->eigen_intra_op_thread_pool_device()); HloExecutionProfile hlo_execution_profile; - ServiceExecutableRunOptions service_run_options(run_options, - backend_->StreamBorrower()); + ServiceExecutableRunOptions service_run_options( + run_options, backend_->StreamBorrower(), + backend_->inter_op_thread_pool()); TF_ASSIGN_OR_RETURN( se::DeviceMemoryBase result, executable->ExecuteOnStream(&service_run_options, arguments, diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc index 7fe4c9020f..7fcf687655 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.cc +++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc @@ -17,12 +17,19 @@ limitations under the License. #include <vector> +#define EIGEN_USE_THREADS + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/core/common_runtime/eigen_thread_pool.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/cpu_info.h" +#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -91,16 +98,34 @@ int64 TestAllocator::deallocation_count(int device_ordinal) const { return allocator_; } +// Define this in .cc file to avoid having to include eigen or forward declare +// these types in the header. +struct LocalClientTestBase::EigenThreadPoolWrapper { + explicit EigenThreadPoolWrapper() + : pool(new tensorflow::thread::ThreadPool( + tensorflow::Env::Default(), "XLAEigenTest", /*num_threads=*/2)), + wrapper(new tensorflow::EigenThreadPoolWrapper(pool.get())), + device(new Eigen::ThreadPoolDevice(wrapper.get(), + wrapper->NumThreads())) {} + + std::unique_ptr<tensorflow::thread::ThreadPool> pool; + std::unique_ptr<tensorflow::EigenThreadPoolWrapper> wrapper; + std::unique_ptr<Eigen::ThreadPoolDevice> device; +}; + LocalClientTestBase::LocalClientTestBase( perftools::gputools::Platform* platform) : local_client_( - ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie()) { + ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie()), + thread_pool_wrapper_(new EigenThreadPoolWrapper()) { stream_executor_ = PlatformUtil::GetStreamExecutors(local_client_->platform()) .ValueOrDie()[local_client_->default_device_ordinal()]; transfer_manager_ = TransferManager::GetForPlatform(local_client_->platform()).ValueOrDie(); } +LocalClientTestBase::~LocalClientTestBase() {} + std::unique_ptr<ScopedShapedBuffer> LocalClientTestBase::LiteralToScopedShapedBuffer(const Literal& literal) { return LiteralToScopedShapedBuffer(literal, @@ -190,8 +215,7 @@ ExecutableRunOptions LocalClientTestBase::DefaultExecutableRunOptions() const { ExecutableRunOptions run_options; run_options.set_inter_op_thread_pool( local_client_->backend().inter_op_thread_pool()); - run_options.set_intra_op_thread_pool( - local_client_->backend().eigen_intra_op_thread_pool_device()); + run_options.set_intra_op_thread_pool(thread_pool_wrapper_->device.get()); run_options.set_allocator(GetOrCreateAllocator(local_client_->platform())); return run_options; } diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.h b/tensorflow/compiler/xla/tests/local_client_test_base.h index 4e7b05cea6..e3c3bb46cf 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.h +++ b/tensorflow/compiler/xla/tests/local_client_test_base.h @@ -74,8 +74,10 @@ class TestAllocator : public StreamExecutorMemoryAllocator { // A base class for tests which exercise the LocalClient interface. class LocalClientTestBase : public ::testing::Test { protected: + struct EigenThreadPoolWrapper; explicit LocalClientTestBase( perftools::gputools::Platform* platform = nullptr); + virtual ~LocalClientTestBase(); static TestAllocator* GetOrCreateAllocator( perftools::gputools::Platform* platform); @@ -142,6 +144,8 @@ class LocalClientTestBase : public ::testing::Test { TransferManager* transfer_manager_; LocalClient* local_client_; + + std::unique_ptr<EigenThreadPoolWrapper> thread_pool_wrapper_; }; } // namespace xla |