-rw-r--r--  tensorflow/compiler/jit/kernels/xla_local_launch_op.cc            16
-rw-r--r--  tensorflow/compiler/jit/kernels/xla_local_launch_op.h              3
-rw-r--r--  tensorflow/compiler/xla/client/client_library.cc                  12
-rw-r--r--  tensorflow/compiler/xla/client/client_library.h                    5
-rw-r--r--  tensorflow/compiler/xla/client/local_client.cc                    11
-rw-r--r--  tensorflow/compiler/xla/service/backend.cc                        64
-rw-r--r--  tensorflow/compiler/xla/service/backend.h                         32
-rw-r--r--  tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc    20
-rw-r--r--  tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h      4
-rw-r--r--  tensorflow/compiler/xla/service/local_service.cc                   9
-rw-r--r--  tensorflow/compiler/xla/service/service.cc                        25
-rw-r--r--  tensorflow/compiler/xla/service/service.h                          5
-rw-r--r--  tensorflow/compiler/xla/service/service_executable_run_options.h  14
-rw-r--r--  tensorflow/compiler/xla/tests/BUILD                                3
-rw-r--r--  tensorflow/compiler/xla/tests/hlo_test_base.cc                     5
-rw-r--r--  tensorflow/compiler/xla/tests/local_client_test_base.cc           30
-rw-r--r--  tensorflow/compiler/xla/tests/local_client_test_base.h             4
17 files changed, 210 insertions, 52 deletions
diff --git a/tensorflow/compiler/jit/kernels/xla_local_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_local_launch_op.cc
index 032ded54e6..a58abcbdff 100644
--- a/tensorflow/compiler/jit/kernels/xla_local_launch_op.cc
+++ b/tensorflow/compiler/jit/kernels/xla_local_launch_op.cc
@@ -159,13 +159,17 @@ XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
}
}
-Status XlaLocalLaunchOp::BuildCompilationCache(XlaCompilationCache** cache) {
+Status XlaLocalLaunchOp::BuildCompilationCache(OpKernelContext* ctx,
+ XlaCompilationCache** cache) {
auto platform = gpu::MultiPlatformManager::PlatformWithId(platform_id_);
if (!platform.ok()) {
return StreamExecutorUtil::ConvertStatus(platform.status());
}
- auto client =
- xla::ClientLibrary::GetOrCreateLocalClient(platform.ValueOrDie());
+ xla::LocalClientOptions client_options;
+ client_options.set_platform(platform.ValueOrDie());
+ client_options.set_intra_op_parallelism_threads(
+ ctx->device()->tensorflow_cpu_worker_threads()->num_threads);
+ auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options);
if (!client.ok()) {
return client.status();
}
@@ -194,8 +198,8 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
XlaCompilationCache* cache;
OP_REQUIRES_OK(ctx, rm->LookupOrCreate<XlaCompilationCache>(
rm->default_container(), "xla_cache", &cache,
- [this](XlaCompilationCache** cache) {
- return BuildCompilationCache(cache);
+ [this, ctx](XlaCompilationCache** cache) {
+ return BuildCompilationCache(ctx, cache);
}));
// Hold the reference to the JIT during evaluation. (We could probably
// free it sooner because the ResourceMgr will retain a reference, but
@@ -264,8 +268,6 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
xla::ExecutableRunOptions run_options;
run_options.set_stream(stream);
run_options.set_allocator(&xla_allocator);
- run_options.set_inter_op_thread_pool(
- ctx->device()->tensorflow_cpu_worker_threads()->workers);
run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
Env* env = Env::Default();
auto start_time = env->NowMicros();
diff --git a/tensorflow/compiler/jit/kernels/xla_local_launch_op.h b/tensorflow/compiler/jit/kernels/xla_local_launch_op.h
index 51887cf013..5e4d3336a9 100644
--- a/tensorflow/compiler/jit/kernels/xla_local_launch_op.h
+++ b/tensorflow/compiler/jit/kernels/xla_local_launch_op.h
@@ -44,7 +44,8 @@ class XlaLocalLaunchOp : public OpKernel {
private:
// Builds a XlaCompilationCache class suitable for the current device.
- Status BuildCompilationCache(XlaCompilationCache** compiler);
+ Status BuildCompilationCache(OpKernelContext* ctx,
+ XlaCompilationCache** compiler);
DeviceType device_type_;
NameAttrList function_;
diff --git a/tensorflow/compiler/xla/client/client_library.cc b/tensorflow/compiler/xla/client/client_library.cc
index eb9a7ff2ac..8238261e1c 100644
--- a/tensorflow/compiler/xla/client/client_library.cc
+++ b/tensorflow/compiler/xla/client/client_library.cc
@@ -43,6 +43,16 @@ int LocalClientOptions::number_of_replicas() const {
return number_of_replicas_;
}
+LocalClientOptions& LocalClientOptions::set_intra_op_parallelism_threads(
+ int num_threads) {
+ intra_op_parallelism_threads_ = num_threads;
+ return *this;
+}
+
+int LocalClientOptions::intra_op_parallelism_threads() const {
+ return intra_op_parallelism_threads_;
+}
+
/* static */ ClientLibrary& ClientLibrary::Singleton() {
static ClientLibrary* c = new ClientLibrary;
return *c;
@@ -77,6 +87,8 @@ ClientLibrary::~ClientLibrary() = default;
ServiceOptions service_options;
service_options.set_platform(platform);
service_options.set_number_of_replicas(replica_count);
+ service_options.set_intra_op_parallelism_threads(
+ options.intra_op_parallelism_threads());
auto instance = MakeUnique<LocalInstance>();
TF_ASSIGN_OR_RETURN(instance->service,
diff --git a/tensorflow/compiler/xla/client/client_library.h b/tensorflow/compiler/xla/client/client_library.h
index 49f4541437..3ddd235d0e 100644
--- a/tensorflow/compiler/xla/client/client_library.h
+++ b/tensorflow/compiler/xla/client/client_library.h
@@ -53,9 +53,14 @@ class LocalClientOptions {
LocalClientOptions& set_number_of_replicas(int number_of_replicas);
int number_of_replicas() const;
+ // Sets the thread pool size for parallel execution of an individual operator.
+ LocalClientOptions& set_intra_op_parallelism_threads(int num_threads);
+ int intra_op_parallelism_threads() const;
+
private:
perftools::gputools::Platform* platform_ = nullptr;
int number_of_replicas_ = -1;
+ int intra_op_parallelism_threads_ = -1;
};
class ClientLibrary {
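
A minimal usage sketch of the new option, not part of the patch; `platform` is assumed to be a pointer the caller already holds, and the thread count 4 is illustrative:

    xla::LocalClientOptions client_options;
    client_options.set_platform(platform);
    client_options.set_number_of_replicas(1);            // -1 (default) reads the xla_replicas flag
    client_options.set_intra_op_parallelism_threads(4);  // -1 (default) uses NumSchedulableCPUs()
    auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options);
    if (!client.ok()) {
      return client.status();
    }
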
diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc
index 0de58ea7dc..02cf57e763 100644
--- a/tensorflow/compiler/xla/client/local_client.cc
+++ b/tensorflow/compiler/xla/client/local_client.cc
@@ -185,8 +185,15 @@ StatusOr<std::unique_ptr<ShapedBuffer>> LocalExecutable::Run(
if (options.allocator() == nullptr) {
actual_options.set_allocator(backend_->memory_allocator());
}
- ServiceExecutableRunOptions service_options(actual_options,
- backend_->StreamBorrower());
+
+ // For local client execution on CPU backends:
+ // *) The thread pool used for eigen CPU ops is from
+ // ExecutableRunOptions.eigen_intra_op_thread_pool.
+ // *) The thread pool used for XLA CPU ops is from
+ // backend_->eigen_intra_op_thread_pool().
+ ServiceExecutableRunOptions service_options(
+ actual_options, backend_->StreamBorrower(),
+ backend_->eigen_intra_op_thread_pool());
if (executable_->dumping()) {
return ExecuteAndDump(&service_options, arguments);
diff --git a/tensorflow/compiler/xla/service/backend.cc b/tensorflow/compiler/xla/service/backend.cc
index 5c05417c6d..1913617fec 100644
--- a/tensorflow/compiler/xla/service/backend.cc
+++ b/tensorflow/compiler/xla/service/backend.cc
@@ -41,13 +41,39 @@ namespace se = ::perftools::gputools;
namespace xla {
+BackendOptions& BackendOptions::set_platform(
+ perftools::gputools::Platform* platform) {
+ platform_ = platform;
+ return *this;
+}
+
+perftools::gputools::Platform* BackendOptions::platform() const {
+ return platform_;
+}
+
+BackendOptions& BackendOptions::set_number_of_replicas(int number_of_replicas) {
+ number_of_replicas_ = number_of_replicas;
+ return *this;
+}
+
+int BackendOptions::number_of_replicas() const { return number_of_replicas_; }
+
+BackendOptions& BackendOptions::set_intra_op_parallelism_threads(
+ int num_threads) {
+ intra_op_parallelism_threads_ = num_threads;
+ return *this;
+}
+
+int BackendOptions::intra_op_parallelism_threads() const {
+ return intra_op_parallelism_threads_;
+}
+
// Define this in .cc file to avoid having to include eigen or forward declare
// these types in the header.
struct Backend::EigenThreadPoolWrapper {
- explicit EigenThreadPoolWrapper()
- : pool(new tensorflow::thread::ThreadPool(
- tensorflow::Env::Default(), "XLAEigen",
- tensorflow::port::NumSchedulableCPUs())),
+ explicit EigenThreadPoolWrapper(const int num_threads)
+ : pool(new tensorflow::thread::ThreadPool(tensorflow::Env::Default(),
+ "XLAEigen", num_threads)),
wrapper(new tensorflow::EigenThreadPoolWrapper(pool.get())),
device(new Eigen::ThreadPoolDevice(wrapper.get(),
wrapper->NumThreads())) {}
@@ -58,18 +84,21 @@ struct Backend::EigenThreadPoolWrapper {
};
/* static */ StatusOr<std::unique_ptr<Backend>> Backend::CreateBackend(
- perftools::gputools::Platform* platform, int64 replica_count) {
+ const BackendOptions& options) {
+ int64 replica_count = options.number_of_replicas();
if (replica_count == -1) {
legacy_flags::BackendFlags* flags = legacy_flags::GetBackendFlags();
replica_count = flags->xla_replicas;
}
+ perftools::gputools::Platform* platform = options.platform();
TF_ASSIGN_OR_RETURN(auto compiler, Compiler::GetForPlatform(platform));
TF_ASSIGN_OR_RETURN(auto stream_executors,
PlatformUtil::GetStreamExecutors(platform));
TF_ASSIGN_OR_RETURN(auto transfer_manager,
TransferManager::GetForPlatform(platform));
- std::unique_ptr<Backend> backend(new Backend(
- replica_count, platform, compiler, stream_executors, transfer_manager));
+ std::unique_ptr<Backend> backend(
+ new Backend(replica_count, platform, compiler, stream_executors,
+ transfer_manager, options.intra_op_parallelism_threads()));
TF_RETURN_IF_ERROR(backend->PoolStreams(kInitialStreamsToPool,
backend->default_stream_executor()));
return std::move(backend);
@@ -79,7 +108,9 @@ struct Backend::EigenThreadPoolWrapper {
Backend::CreateDefaultBackend() {
TF_ASSIGN_OR_RETURN(se::Platform * platform,
PlatformUtil::GetDefaultPlatform());
- return CreateBackend(platform);
+ BackendOptions backend_options;
+ backend_options.set_platform(platform);
+ return CreateBackend(backend_options);
}
tensorflow::Status Backend::PoolStreams(int n, se::StreamExecutor* executor) {
@@ -114,7 +145,7 @@ Backend::Backend(
int64 replica_count, perftools::gputools::Platform* platform,
Compiler* compiler,
tensorflow::gtl::ArraySlice<se::StreamExecutor*> stream_executors,
- TransferManager* transfer_manager)
+ TransferManager* transfer_manager, int intra_op_parallelism_threads)
: platform_(platform),
compiler_(compiler),
transfer_manager_(transfer_manager),
@@ -144,7 +175,11 @@ Backend::Backend(
inter_op_thread_pool_.reset(new tensorflow::thread::ThreadPool(
tensorflow::Env::Default(), "xla_inter_op",
tensorflow::port::NumSchedulableCPUs()));
- intra_op_thread_pool_wrapper_.reset(new EigenThreadPoolWrapper());
+ const int num_threads = intra_op_parallelism_threads > 0
+ ? intra_op_parallelism_threads
+ : tensorflow::port::NumSchedulableCPUs();
+ intra_op_thread_pool_wrapper_.reset(
+ new EigenThreadPoolWrapper(num_threads));
}
}
@@ -190,10 +225,17 @@ tensorflow::thread::ThreadPool* Backend::inter_op_thread_pool() const {
const Eigen::ThreadPoolDevice* Backend::eigen_intra_op_thread_pool_device()
const {
- if (intra_op_thread_pool_wrapper_ == nullptr) return nullptr;
+ if (intra_op_thread_pool_wrapper_ == nullptr) {
+ return nullptr;
+ }
return intra_op_thread_pool_wrapper_->device.get();
}
+tensorflow::thread::ThreadPool* Backend::eigen_intra_op_thread_pool() const {
+ if (intra_op_thread_pool_wrapper_ == nullptr) return nullptr;
+ return intra_op_thread_pool_wrapper_->pool.get();
+}
+
StatusOr<perftools::gputools::StreamExecutor*> Backend::stream_executor(
int device_ordinal) const {
if (device_ordinal < 0 ||
diff --git a/tensorflow/compiler/xla/service/backend.h b/tensorflow/compiler/xla/service/backend.h
index 9f6829b7d9..1068bac277 100644
--- a/tensorflow/compiler/xla/service/backend.h
+++ b/tensorflow/compiler/xla/service/backend.h
@@ -39,6 +39,31 @@ struct ThreadPoolDevice;
namespace xla {
+// Options to configure the backend when it is created.
+class BackendOptions {
+ public:
+ // Set the platform backing the backend, or nullptr for the default platform.
+ BackendOptions& set_platform(perftools::gputools::Platform* platform);
+ perftools::gputools::Platform* platform() const;
+
+ // Set the number of replicas to use when compiling replicated
+ // programs. The default is -1 meaning that the value is read from
+ // the xla_replicas flag.
+ BackendOptions& set_number_of_replicas(int number_of_replicas);
+ int number_of_replicas() const;
+
+ // Sets the thread pool size for parallel execution of an individual operator.
+ // The default value of -1 will result in initializing the thread pool with
+ // the number of threads equal to the number of cores in the system.
+ BackendOptions& set_intra_op_parallelism_threads(int num_threads);
+ int intra_op_parallelism_threads() const;
+
+ private:
+ perftools::gputools::Platform* platform_ = nullptr;
+ int number_of_replicas_ = -1;
+ int intra_op_parallelism_threads_ = -1;
+};
+
// Class which encapsulates an XLA backend. It includes everything necessary
// to compile and execute computations on a particular platform.
//
@@ -53,9 +78,9 @@ class Backend {
static constexpr int kInitialStreamsToPool = 8;
// Creates a new backend for the given platform with the given number of
- // replicas. A value of -1 means to use the flag value.
+ // replicas.
static StatusOr<std::unique_ptr<Backend>> CreateBackend(
- perftools::gputools::Platform* platform, int64 replica_count = -1);
+ const BackendOptions& options);
// Creates a backend for the default platform. The default platform is defined
// in PlatformUtil.
@@ -150,6 +175,7 @@ class Backend {
// For the host platform, returns the configured eigen threadpool device to be
// used for scheduling work. For other platforms, returns NULL.
const Eigen::ThreadPoolDevice* eigen_intra_op_thread_pool_device() const;
+ tensorflow::thread::ThreadPool* eigen_intra_op_thread_pool() const;
// Resets the devices associated with this backend.
Status ResetDevices();
@@ -160,7 +186,7 @@ class Backend {
Compiler* compiler,
tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
stream_executors,
- TransferManager* transfer_manager);
+ TransferManager* transfer_manager, int intra_op_parallelism_threads);
Backend(const Backend&) = delete;
Backend& operator=(const Backend&) = delete;
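
The setters return *this, so they chain; a minimal sketch of constructing a backend directly from these options (the thread count is illustrative, and `platform` is assumed in scope):

    BackendOptions backend_options;
    backend_options.set_platform(platform)
        .set_number_of_replicas(2)
        .set_intra_op_parallelism_threads(8);
    TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> backend,
                        Backend::CreateBackend(backend_options));
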
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc
index 7a4723e8d7..cadad10910 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc
+++ b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc
@@ -146,7 +146,7 @@ Status ParallelCpuExecutable::AllocateBuffers(
}
Status ParallelCpuExecutable::ExecuteComputeFunctions(
- const ExecutableRunOptions* run_options,
+ const ServiceExecutableRunOptions* run_options,
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> buffers,
HloExecutionProfile* hlo_execution_profile) {
@@ -160,7 +160,7 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions(
}
Status ParallelCpuExecutable::ExecuteComputeFunctions(
- const ExecutableRunOptions* run_options,
+ const ServiceExecutableRunOptions* run_options,
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments,
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> buffers,
HloExecutionProfile* hlo_execution_profile) {
@@ -214,7 +214,7 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions(
void** temps_array = buffer_pointers.data();
uint64* profile_counters_array = profile_counters.data();
- auto* thread_pool = CHECK_NOTNULL(run_options->inter_op_thread_pool());
+ auto* thread_pool = CHECK_NOTNULL(run_options->xla_intra_op_thread_pool());
tensorflow::mutex completion_queue_lock;
tensorflow::condition_variable completion_queue_cv;
std::deque<HloInstruction*> completion_queue;
@@ -251,11 +251,12 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions(
});
auto function = FindOrDie(functions, instruction);
// The thread pool entry takes ownership of |operand_buffers|.
+ const auto* exec_run_options = &run_options->run_options();
thread_pool->Schedule([instruction, &completion_queue,
&completion_queue_lock, &completion_queue_cv,
- result_buffer, run_options, operand_buffers,
+ result_buffer, exec_run_options, operand_buffers,
temps_array, profile_counters_array, function] {
- function(result_buffer, run_options, operand_buffers, temps_array,
+ function(result_buffer, exec_run_options, operand_buffers, temps_array,
profile_counters_array);
delete[] operand_buffers;
// Push the completed HLO instruction on the queue, the main thread
@@ -345,9 +346,8 @@ ParallelCpuExecutable::ExecuteOnStream(
const BufferAllocation::Index result_index = result_slice.index();
VLOG(3) << "result index: " << result_index;
- TF_RETURN_IF_ERROR(ExecuteComputeFunctions(&run_options->run_options(),
- arguments, device_allocations,
- hlo_execution_profile));
+ TF_RETURN_IF_ERROR(ExecuteComputeFunctions(
+ run_options, arguments, device_allocations, hlo_execution_profile));
// Mark the buffers that are actually live (used in the output) when the
// computation finishes executing.
@@ -400,8 +400,8 @@ StatusOr<std::unique_ptr<ShapedBuffer>> ParallelCpuExecutable::ExecuteOnStream(
TF_RETURN_IF_ERROR(AllocateBuffers(
memory_allocator, stream->parent()->device_ordinal(), &buffers));
- TF_RETURN_IF_ERROR(ExecuteComputeFunctions(
- &run_options->run_options(), arguments, buffers, hlo_execution_profile));
+ TF_RETURN_IF_ERROR(ExecuteComputeFunctions(run_options, arguments, buffers,
+ hlo_execution_profile));
// Copy DeviceMemoryBase values which contain the array(s) of the result into
// the respective location in ShapedBuffer which is returned to the caller.
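
A self-contained sketch of the scheduling pattern used above: closures run on the xla_intra_op_thread_pool, and a mutex/condition_variable pair hands finished items back to the coordinating thread. `num_functions` and the int queue are hypothetical stand-ins for the HLO-specific pieces:

    tensorflow::mutex completion_queue_lock;
    tensorflow::condition_variable completion_queue_cv;
    std::deque<int> completion_queue;  // stands in for std::deque<HloInstruction*>
    tensorflow::thread::ThreadPool* thread_pool =
        CHECK_NOTNULL(run_options->xla_intra_op_thread_pool());
    for (int i = 0; i < num_functions; ++i) {
      thread_pool->Schedule([i, &completion_queue, &completion_queue_lock,
                             &completion_queue_cv] {
        // ... invoke the compiled compute function for work item i ...
        tensorflow::mutex_lock lock(completion_queue_lock);
        completion_queue.push_back(i);
        completion_queue_cv.notify_all();
      });
    }
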
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h
index 7223de9f07..6e1239d590 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h
+++ b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h
@@ -96,14 +96,14 @@ class ParallelCpuExecutable : public Executable {
// Calls the generated functions in 'function_names_', performing the
// computation with the given arguments using the supplied buffers.
Status ExecuteComputeFunctions(
- const ExecutableRunOptions* run_options,
+ const ServiceExecutableRunOptions* run_options,
tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
arguments,
tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
buffers,
HloExecutionProfile* hlo_execution_profile);
Status ExecuteComputeFunctions(
- const ExecutableRunOptions* run_options,
+ const ServiceExecutableRunOptions* run_options,
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
buffers,
diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc
index 8184dbabc8..78d21233c7 100644
--- a/tensorflow/compiler/xla/service/local_service.cc
+++ b/tensorflow/compiler/xla/service/local_service.cc
@@ -60,9 +60,12 @@ namespace xla {
TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform());
}
- TF_ASSIGN_OR_RETURN(
- std::unique_ptr<Backend> backend,
- Backend::CreateBackend(platform, options.number_of_replicas()));
+ BackendOptions backend_options;
+ backend_options.set_platform(platform)
+ .set_number_of_replicas(options.number_of_replicas())
+ .set_intra_op_parallelism_threads(options.intra_op_parallelism_threads());
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> backend,
+ Backend::CreateBackend(backend_options));
TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> compute_constant_backend,
CreateComputeConstantBackend());
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index c001e705de..42450dfcae 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -112,6 +112,16 @@ ServiceOptions& ServiceOptions::set_number_of_replicas(int number_of_replicas) {
int ServiceOptions::number_of_replicas() const { return number_of_replicas_; }
+ServiceOptions& ServiceOptions::set_intra_op_parallelism_threads(
+ int num_threads) {
+ intra_op_parallelism_threads_ = num_threads;
+ return *this;
+}
+
+int ServiceOptions::intra_op_parallelism_threads() const {
+ return intra_op_parallelism_threads_;
+}
+
/* static */ StatusOr<std::unique_ptr<Service>> Service::NewService(
perftools::gputools::Platform* platform) {
ServiceOptions default_options;
@@ -126,9 +136,10 @@ int ServiceOptions::number_of_replicas() const { return number_of_replicas_; }
if (platform == nullptr) {
TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform());
}
- TF_ASSIGN_OR_RETURN(
- execute_backend,
- Backend::CreateBackend(platform, options.number_of_replicas()));
+ BackendOptions backend_options;
+ backend_options.set_platform(platform);
+ backend_options.set_number_of_replicas(options.number_of_replicas());
+ TF_ASSIGN_OR_RETURN(execute_backend, Backend::CreateBackend(backend_options));
TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> compute_constant_backend,
CreateComputeConstantBackend());
std::unique_ptr<Service> service(new Service(
@@ -142,7 +153,10 @@ Service::CreateComputeConstantBackend() {
PlatformUtil::GetSupportedPlatforms());
for (auto* platform : platforms) {
if (platform->id() == se::host::kHostPlatformId) {
- return Backend::CreateBackend(platform, /*replica_count=*/1);
+ BackendOptions backend_options;
+ backend_options.set_platform(platform);
+ backend_options.set_number_of_replicas(1);
+ return Backend::CreateBackend(backend_options);
}
}
return NotFound("CPU platform not found");
@@ -573,7 +587,8 @@ StatusOr<GlobalDataHandle> Service::ExecuteAndRegisterResult(
options.set_inter_op_thread_pool(backend->inter_op_thread_pool());
options.set_intra_op_thread_pool(
backend->eigen_intra_op_thread_pool_device());
- run_options.emplace_back(options, backend->StreamBorrower());
+ run_options.emplace_back(options, backend->StreamBorrower(),
+ backend->inter_op_thread_pool());
}
perftools::gputools::DeviceMemoryBase result;
diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h
index 0e0e7c4e21..05a955137f 100644
--- a/tensorflow/compiler/xla/service/service.h
+++ b/tensorflow/compiler/xla/service/service.h
@@ -63,9 +63,14 @@ class ServiceOptions {
ServiceOptions& set_number_of_replicas(int number_of_replicas);
int number_of_replicas() const;
+ // Sets the thread pool size for parallel execution of an individual operator.
+ ServiceOptions& set_intra_op_parallelism_threads(int num_threads);
+ int intra_op_parallelism_threads() const;
+
private:
perftools::gputools::Platform* platform_ = nullptr;
int number_of_replicas_ = -1;
+ int intra_op_parallelism_threads_ = -1;
};
// The XLA service object, which is the same across all
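
As with LocalClientOptions, a sketch of configuring a service through these setters, assuming the NewService(ServiceOptions) overload that the hunk above executes inside; values are illustrative:

    ServiceOptions options;
    options.set_platform(platform)
        .set_number_of_replicas(1)
        .set_intra_op_parallelism_threads(2);
    TF_ASSIGN_OR_RETURN(std::unique_ptr<Service> service,
                        Service::NewService(options));
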
diff --git a/tensorflow/compiler/xla/service/service_executable_run_options.h b/tensorflow/compiler/xla/service/service_executable_run_options.h
index 0d4b214f5f..017e5ef09e 100644
--- a/tensorflow/compiler/xla/service/service_executable_run_options.h
+++ b/tensorflow/compiler/xla/service/service_executable_run_options.h
@@ -30,10 +30,12 @@ class ServiceExecutableRunOptions {
using StreamBorrower =
std::function<StatusOr<Pool<perftools::gputools::Stream>::SmartPtr>(int)>;
- explicit ServiceExecutableRunOptions(ExecutableRunOptions run_options,
- StreamBorrower borrow_stream = nullptr)
+ explicit ServiceExecutableRunOptions(
+ ExecutableRunOptions run_options, StreamBorrower borrow_stream = nullptr,
+ tensorflow::thread::ThreadPool* xla_intra_op_thread_pool = nullptr)
: run_options_(std::move(run_options)),
- borrow_stream_(std::move(borrow_stream)) {}
+ borrow_stream_(std::move(borrow_stream)),
+ xla_intra_op_thread_pool_(xla_intra_op_thread_pool) {}
// Returns reference or pointer to `ExecutableRunOptions` member.
const ExecutableRunOptions& run_options() const { return run_options_; }
@@ -53,9 +55,15 @@ class ServiceExecutableRunOptions {
: Status(tensorflow::error::UNIMPLEMENTED, "No stream cache");
}
+ // Returns the thread pool used to execute XLA ops on the CPU backend.
+ tensorflow::thread::ThreadPool* xla_intra_op_thread_pool() const {
+ return xla_intra_op_thread_pool_;
+ }
+
private:
ExecutableRunOptions run_options_;
StreamBorrower borrow_stream_;
+ tensorflow::thread::ThreadPool* xla_intra_op_thread_pool_;
};
} // namespace xla
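
A sketch of the three-argument construction as the CPU local-client path above now performs it; `stream`, `allocator`, and `backend` are assumed in scope:

    ExecutableRunOptions run_options;
    run_options.set_stream(stream);
    run_options.set_allocator(allocator);
    ServiceExecutableRunOptions service_options(
        run_options, backend->StreamBorrower(),
        backend->eigen_intra_op_thread_pool());
    // Later, on the CPU executable side:
    tensorflow::thread::ThreadPool* pool =
        service_options.xla_intra_op_thread_pool();
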
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index d7644a0513..e0c2b9ab09 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -200,11 +200,13 @@ cc_library(
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/compiler/xla/service:local_service",
"//tensorflow/compiler/xla/service:platform_util",
+ "//tensorflow/compiler/xla/service:pool",
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/compiler/xla/service:transfer_manager",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "//third_party/eigen3",
],
)
@@ -1362,6 +1364,7 @@ cc_test(
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/service:computation_tracker",
"//tensorflow/compiler/xla/service:local_service",
+ "//tensorflow/core:lib",
"//tensorflow/core:test_main",
],
)
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc
index 4e956bc00c..f741ff38b5 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.cc
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc
@@ -111,8 +111,9 @@ StatusOr<se::DeviceMemoryBase> HloTestBase::Execute(
backend_->eigen_intra_op_thread_pool_device());
HloExecutionProfile hlo_execution_profile;
- ServiceExecutableRunOptions service_run_options(run_options,
- backend_->StreamBorrower());
+ ServiceExecutableRunOptions service_run_options(
+ run_options, backend_->StreamBorrower(),
+ backend_->inter_op_thread_pool());
TF_ASSIGN_OR_RETURN(
se::DeviceMemoryBase result,
executable->ExecuteOnStream(&service_run_options, arguments,
diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc
index 7fe4c9020f..7fcf687655 100644
--- a/tensorflow/compiler/xla/tests/local_client_test_base.cc
+++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc
@@ -17,12 +17,19 @@ limitations under the License.
#include <vector>
+#define EIGEN_USE_THREADS
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/cpu_info.h"
+#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
@@ -91,16 +98,34 @@ int64 TestAllocator::deallocation_count(int device_ordinal) const {
return allocator_;
}
+// Define this in .cc file to avoid having to include eigen or forward declare
+// these types in the header.
+struct LocalClientTestBase::EigenThreadPoolWrapper {
+ explicit EigenThreadPoolWrapper()
+ : pool(new tensorflow::thread::ThreadPool(
+ tensorflow::Env::Default(), "XLAEigenTest", /*num_threads=*/2)),
+ wrapper(new tensorflow::EigenThreadPoolWrapper(pool.get())),
+ device(new Eigen::ThreadPoolDevice(wrapper.get(),
+ wrapper->NumThreads())) {}
+
+ std::unique_ptr<tensorflow::thread::ThreadPool> pool;
+ std::unique_ptr<tensorflow::EigenThreadPoolWrapper> wrapper;
+ std::unique_ptr<Eigen::ThreadPoolDevice> device;
+};
+
LocalClientTestBase::LocalClientTestBase(
perftools::gputools::Platform* platform)
: local_client_(
- ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie()) {
+ ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie()),
+ thread_pool_wrapper_(new EigenThreadPoolWrapper()) {
stream_executor_ = PlatformUtil::GetStreamExecutors(local_client_->platform())
.ValueOrDie()[local_client_->default_device_ordinal()];
transfer_manager_ =
TransferManager::GetForPlatform(local_client_->platform()).ValueOrDie();
}
+LocalClientTestBase::~LocalClientTestBase() {}
+
std::unique_ptr<ScopedShapedBuffer>
LocalClientTestBase::LiteralToScopedShapedBuffer(const Literal& literal) {
return LiteralToScopedShapedBuffer(literal,
@@ -190,8 +215,7 @@ ExecutableRunOptions LocalClientTestBase::DefaultExecutableRunOptions() const {
ExecutableRunOptions run_options;
run_options.set_inter_op_thread_pool(
local_client_->backend().inter_op_thread_pool());
- run_options.set_intra_op_thread_pool(
- local_client_->backend().eigen_intra_op_thread_pool_device());
+ run_options.set_intra_op_thread_pool(thread_pool_wrapper_->device.get());
run_options.set_allocator(GetOrCreateAllocator(local_client_->platform()));
return run_options;
}
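
Unpacked, the wrapper is just three layers; a sketch of the same wiring without the struct (all types come from the headers included at the top of this file):

    tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(),
                                        "XLAEigenTest", /*num_threads=*/2);
    tensorflow::EigenThreadPoolWrapper eigen_wrapper(&pool);
    Eigen::ThreadPoolDevice device(&eigen_wrapper, eigen_wrapper.NumThreads());
    ExecutableRunOptions run_options;
    run_options.set_intra_op_thread_pool(&device);
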
diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.h b/tensorflow/compiler/xla/tests/local_client_test_base.h
index 4e7b05cea6..e3c3bb46cf 100644
--- a/tensorflow/compiler/xla/tests/local_client_test_base.h
+++ b/tensorflow/compiler/xla/tests/local_client_test_base.h
@@ -74,8 +74,10 @@ class TestAllocator : public StreamExecutorMemoryAllocator {
// A base class for tests which exercise the LocalClient interface.
class LocalClientTestBase : public ::testing::Test {
protected:
+ struct EigenThreadPoolWrapper;
explicit LocalClientTestBase(
perftools::gputools::Platform* platform = nullptr);
+ virtual ~LocalClientTestBase();
static TestAllocator* GetOrCreateAllocator(
perftools::gputools::Platform* platform);
@@ -142,6 +144,8 @@ class LocalClientTestBase : public ::testing::Test {
TransferManager* transfer_manager_;
LocalClient* local_client_;
+
+ std::unique_ptr<EigenThreadPoolWrapper> thread_pool_wrapper_;
};
} // namespace xla