diff options
author | HyoukJoong Lee <hyouklee@google.com> | 2017-06-14 18:46:32 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-06-14 18:50:25 -0700 |
commit | 7d3497a639670d9c31d09185ff97b852f0fbe101 (patch) | |
tree | 8f38e3073e522477d7a5b6c270a99eee28559ec8 /tensorflow/compiler/xla/service/backend.h | |
parent | 0fa19543aa0c4fe84851cc00b0d731260e825739 (diff) |
Add ComputationPlacer to assign device ids for replicated model-parallel computations.
PiperOrigin-RevId: 159056198
Diffstat (limited to 'tensorflow/compiler/xla/service/backend.h')
-rw-r--r-- | tensorflow/compiler/xla/service/backend.h | 31 |
1 files changed, 15 insertions, 16 deletions
diff --git a/tensorflow/compiler/xla/service/backend.h b/tensorflow/compiler/xla/service/backend.h index e0b15dc43f..3ead274cde 100644 --- a/tensorflow/compiler/xla/service/backend.h +++ b/tensorflow/compiler/xla/service/backend.h @@ -22,6 +22,7 @@ limitations under the License. #include <vector> #include "tensorflow/compiler/xla/service/compiler.h" +#include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/pool.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" @@ -40,6 +41,8 @@ struct ThreadPoolDevice; namespace xla { // Options to configure the backend when it is created. +// +// TODO(b/62588571): Remove the notion of replicas from the backend. class BackendOptions { public: // Set the platform backing the backend, or nullptr for the default platform. @@ -92,11 +95,15 @@ class Backend { return memory_allocator_.get(); } TransferManager* transfer_manager() const { return transfer_manager_; } + ComputationPlacer* computation_placer() const { return computation_placer_; } // Returns the number of devices of the platform type which are visible. Not // all of these devices may be usable by XLA. int device_count() const { return stream_executors_.size(); } + // Returns the number of replicas when replication is enabled. + int replica_count() const { return replica_count_; } + // Returns the device ordinal number of the default device. int default_device_ordinal() const; @@ -107,24 +114,13 @@ class Backend { return stream_executors_; } - // Returns the replicas for the default stream executor. - // - // When the number of replicas is R, the first R stream executors are assigned - // to the replicas of the default stream executor. - std::vector<perftools::gputools::StreamExecutor*> Replicas() const; - - // Returns the replicas for the given device_ordinal. The given device ordinal - // is considered to be the first device ordinal among the replicas. Returns an - // error status if the stream executor for the given given device ordinal does - // not exist or if there are not enough stream executors for the replicas. - StatusOr<std::vector<perftools::gputools::StreamExecutor*>> Replicas( - int device_ordinal) const; - - // Return the stream executor for the given device ordinal. + // Returns the stream executor for the given device ordinal. StatusOr<perftools::gputools::StreamExecutor*> stream_executor( int device_ordinal) const; - // Return the stream executor for the default device ordinal. + // Returns the stream executor for the default device ordinal. This stream + // executor can only be used when the number of computations is 1 (replication + // can be > 1). perftools::gputools::StreamExecutor* default_stream_executor() const { CHECK(!stream_executors_.empty()); return stream_executors_[0]; @@ -178,13 +174,16 @@ class Backend { Compiler* compiler, tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*> stream_executors, - TransferManager* transfer_manager, int intra_op_parallelism_threads); + TransferManager* transfer_manager, + ComputationPlacer* computation_placer, + int intra_op_parallelism_threads); Backend(const Backend&) = delete; Backend& operator=(const Backend&) = delete; perftools::gputools::Platform* platform_; Compiler* compiler_; TransferManager* transfer_manager_; + ComputationPlacer* computation_placer_; int64 replica_count_ = -1; // Vector of stream executors. stream_executors_[0] is the default executor. |