about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/compiler/xla/service/backend.h
diff options
context:
space:
mode:
author: HyoukJoong Lee <hyouklee@google.com> 2017-06-14 18:46:32 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2017-06-14 18:50:25 -0700
commit: 7d3497a639670d9c31d09185ff97b852f0fbe101 (patch)
tree: 8f38e3073e522477d7a5b6c270a99eee28559ec8 /tensorflow/compiler/xla/service/backend.h
parent: 0fa19543aa0c4fe84851cc00b0d731260e825739 (diff)
Add ComputationPlacer to assign device ids for replicated model-parallel computations.
PiperOrigin-RevId: 159056198
Diffstat (limited to 'tensorflow/compiler/xla/service/backend.h')
-rw-r--r-- tensorflow/compiler/xla/service/backend.h | 31
1 files changed, 15 insertions, 16 deletions
diff --git a/tensorflow/compiler/xla/service/backend.h b/tensorflow/compiler/xla/service/backend.h
index e0b15dc43f..3ead274cde 100644
--- a/tensorflow/compiler/xla/service/backend.h
+++ b/tensorflow/compiler/xla/service/backend.h
@@ -22,6 +22,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/service/compiler.h"
+#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/pool.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
@@ -40,6 +41,8 @@ struct ThreadPoolDevice;
namespace xla {
// Options to configure the backend when it is created.
+//
+// TODO(b/62588571): Remove the notion of replicas from the backend.
class BackendOptions {
public:
// Set the platform backing the backend, or nullptr for the default platform.
@@ -92,11 +95,15 @@ class Backend {
return memory_allocator_.get();
}
TransferManager* transfer_manager() const { return transfer_manager_; }
+ ComputationPlacer* computation_placer() const { return computation_placer_; }
// Returns the number of devices of the platform type which are visible. Not
// all of these devices may be usable by XLA.
int device_count() const { return stream_executors_.size(); }
+ // Returns the number of replicas when replication is enabled.
+ int replica_count() const { return replica_count_; }
+
// Returns the device ordinal number of the default device.
int default_device_ordinal() const;
@@ -107,24 +114,13 @@ class Backend {
return stream_executors_;
}
- // Returns the replicas for the default stream executor.
- //
- // When the number of replicas is R, the first R stream executors are assigned
- // to the replicas of the default stream executor.
- std::vector<perftools::gputools::StreamExecutor*> Replicas() const;
-
- // Returns the replicas for the given device_ordinal. The given device ordinal
- // is considered to be the first device ordinal among the replicas. Returns an
- // error status if the stream executor for the given given device ordinal does
- // not exist or if there are not enough stream executors for the replicas.
- StatusOr<std::vector<perftools::gputools::StreamExecutor*>> Replicas(
- int device_ordinal) const;
-
- // Return the stream executor for the given device ordinal.
+ // Returns the stream executor for the given device ordinal.
StatusOr<perftools::gputools::StreamExecutor*> stream_executor(
int device_ordinal) const;
- // Return the stream executor for the default device ordinal.
+ // Returns the stream executor for the default device ordinal. This stream
+ // executor can only be used when the number of computations is 1 (replication
+ // can be > 1).
perftools::gputools::StreamExecutor* default_stream_executor() const {
CHECK(!stream_executors_.empty());
return stream_executors_[0];
@@ -178,13 +174,16 @@ class Backend {
Compiler* compiler,
tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
stream_executors,
- TransferManager* transfer_manager, int intra_op_parallelism_threads);
+ TransferManager* transfer_manager,
+ ComputationPlacer* computation_placer,
+ int intra_op_parallelism_threads);
Backend(const Backend&) = delete;
Backend& operator=(const Backend&) = delete;
perftools::gputools::Platform* platform_;
Compiler* compiler_;
TransferManager* transfer_manager_;
+ ComputationPlacer* computation_placer_;
int64 replica_count_ = -1;
// Vector of stream executors. stream_executors_[0] is the default executor.