aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/backend.h
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-05-12 11:37:42 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-12 11:41:14 -0700
commita20ebced22db1be959cdc9875f1a797fd3367712 (patch)
tree620887d2c08f5d52d326e418c466b682ac70c99c /tensorflow/compiler/xla/service/backend.h
parent6ecf9d143f98c318f008199d9cb6da00b483cf45 (diff)
[XLA:CPU] Prep work for thread-parallel XLA CPU backend.
*) Plumbs intra op thread parallelism value through to XLA backend.
*) Service execution uses inter/intra op pools from backend.
*) LocalService execution uses intra op pool from backend for XLA parallel ops, and intra op pool passed in ExecutableRunOptions for eigen ops.
PiperOrigin-RevId: 155891730
Diffstat (limited to 'tensorflow/compiler/xla/service/backend.h')
-rw-r--r--tensorflow/compiler/xla/service/backend.h32
1 file changed, 29 insertions, 3 deletions
diff --git a/tensorflow/compiler/xla/service/backend.h b/tensorflow/compiler/xla/service/backend.h
index 9f6829b7d9..1068bac277 100644
--- a/tensorflow/compiler/xla/service/backend.h
+++ b/tensorflow/compiler/xla/service/backend.h
@@ -39,6 +39,31 @@ struct ThreadPoolDevice;
namespace xla {
+// Options to configure the backend when it is created.
+class BackendOptions {
+ public:
+ // Set the platform backing the backend, or nullptr for the default platform.
+ BackendOptions& set_platform(perftools::gputools::Platform* platform);
+ perftools::gputools::Platform* platform() const;
+
+ // Set the number of replicas to use when compiling replicated
+ // programs. The default is -1 meaning that the value is read from
+ // the xla_replicas flag.
+ BackendOptions& set_number_of_replicas(int number_of_replicas);
+ int number_of_replicas() const;
+
+ // Sets the thread pool size for parallel execution of an individual operator.
+ // The default value of -1 will result in initializing the thread pool with
+ // the number of threads equal to the number of cores in the system.
+ BackendOptions& set_intra_op_parallelism_threads(int num_threads);
+ int intra_op_parallelism_threads() const;
+
+ private:
+ perftools::gputools::Platform* platform_ = nullptr;
+ int number_of_replicas_ = -1;
+ int intra_op_parallelism_threads_ = -1;
+};
+
// Class which encapsulates an XLA backend. It includes everything necessary
// to compile and execute computations on a particular platform.
//
@@ -53,9 +78,9 @@ class Backend {
static constexpr int kInitialStreamsToPool = 8;
// Creates a new backend for the given platform with the given number of
- // replicas. A value of -1 means to use the flag value.
+ // replicas.
static StatusOr<std::unique_ptr<Backend>> CreateBackend(
- perftools::gputools::Platform* platform, int64 replica_count = -1);
+ const BackendOptions& options);
// Creates a backend for the default platform. The default platform is defined
// in PlatformUtil.
@@ -150,6 +175,7 @@ class Backend {
// For the host platform, returns the configured eigen threadpool device to be
// used for scheduling work. For other platforms, returns NULL.
const Eigen::ThreadPoolDevice* eigen_intra_op_thread_pool_device() const;
+ tensorflow::thread::ThreadPool* eigen_intra_op_thread_pool() const;
// Resets the devices associated with this backend.
Status ResetDevices();
@@ -160,7 +186,7 @@ class Backend {
Compiler* compiler,
tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
stream_executors,
- TransferManager* transfer_manager);
+ TransferManager* transfer_manager, int intra_op_parallelism_threads);
Backend(const Backend&) = delete;
Backend& operator=(const Backend&) = delete;