author     HyoukJoong Lee <hyouklee@google.com>             2017-06-14 18:46:32 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2017-06-14 18:50:25 -0700
commit     7d3497a639670d9c31d09185ff97b852f0fbe101 (patch)
tree       8f38e3073e522477d7a5b6c270a99eee28559ec8
parent     0fa19543aa0c4fe84851cc00b0d731260e825739 (diff)
Add ComputationPlacer to assign device ids for replicated model-parallel computations.
PiperOrigin-RevId: 159056198
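
This change adds a per-platform ComputationPlacer, threads its DeviceAssignment through Backend and ExecutableRunOptions, and teaches the XLA service to resolve replicated device handles. As a rough sketch of how the new pieces fit together (the helper function below is hypothetical, not part of this commit):

#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"

// Hypothetical helper: builds run options carrying the device assignment
// produced by the backend's ComputationPlacer for replica_count() replicas
// and `computation_count` model-parallel computations.
xla::StatusOr<xla::ExecutableRunOptions> MakeReplicatedRunOptions(
    xla::Backend* backend, xla::DeviceAssignment* assignment,
    int computation_count) {
  TF_ASSIGN_OR_RETURN(*assignment,
                      backend->computation_placer()->AssignDevices(
                          backend->replica_count(), computation_count));
  xla::ExecutableRunOptions options;
  options.set_allocator(backend->memory_allocator());
  // New in this change: expose the (replica, computation) -> device id
  // mapping to the executable at run time.
  options.set_device_assignment(assignment);
  return options;
}
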
-rw-r--r--  tensorflow/compiler/xla/executable_run_options.cc       10
-rw-r--r--  tensorflow/compiler/xla/executable_run_options.h          6
-rw-r--r--  tensorflow/compiler/xla/service/BUILD                    21
-rw-r--r--  tensorflow/compiler/xla/service/backend.cc               42
-rw-r--r--  tensorflow/compiler/xla/service/backend.h                31
-rw-r--r--  tensorflow/compiler/xla/service/computation_placer.cc   151
-rw-r--r--  tensorflow/compiler/xla/service/computation_placer.h    113
-rw-r--r--  tensorflow/compiler/xla/service/local_service.cc          2
-rw-r--r--  tensorflow/compiler/xla/service/service.cc              206
-rw-r--r--  tensorflow/compiler/xla/service/service.h                13
-rw-r--r--  tensorflow/compiler/xla/tests/BUILD                       3
-rw-r--r--  tensorflow/compiler/xla/xla_data.proto                   25
12 files changed, 489 insertions, 134 deletions
diff --git a/tensorflow/compiler/xla/executable_run_options.cc b/tensorflow/compiler/xla/executable_run_options.cc
index 67f3a6c1df..33d5b6f1d4 100644
--- a/tensorflow/compiler/xla/executable_run_options.cc
+++ b/tensorflow/compiler/xla/executable_run_options.cc
@@ -77,4 +77,14 @@ ExecutionProfile* ExecutableRunOptions::execution_profile() const {
return execution_profile_;
}
+ExecutableRunOptions& ExecutableRunOptions::set_device_assignment(
+ DeviceAssignment* device_assignment) {
+ device_assignment_ = device_assignment;
+ return *this;
+}
+
+DeviceAssignment* ExecutableRunOptions::device_assignment() const {
+ return device_assignment_;
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/executable_run_options.h b/tensorflow/compiler/xla/executable_run_options.h
index 03f2d016ad..deb3ddb203 100644
--- a/tensorflow/compiler/xla/executable_run_options.h
+++ b/tensorflow/compiler/xla/executable_run_options.h
@@ -40,6 +40,7 @@ struct ThreadPoolDevice;
namespace xla {
class DeviceMemoryAllocator;
+class DeviceAssignment;
class ExecutionProfile;
// Class containing options for running a LocalExecutable.
@@ -79,9 +80,14 @@ class ExecutableRunOptions {
ExecutionProfile* execution_profile() const;
ExecutableRunOptions& set_execution_profile(ExecutionProfile* profile);
+ ExecutableRunOptions& set_device_assignment(
+ DeviceAssignment* device_assignment);
+ DeviceAssignment* device_assignment() const;
+
private:
DeviceMemoryAllocator* allocator_ = nullptr;
int device_ordinal_ = -1;
+ DeviceAssignment* device_assignment_ = nullptr;
perftools::gputools::Stream* stream_ = nullptr;
tensorflow::thread::ThreadPool* inter_op_thread_pool_ = nullptr;
const Eigen::ThreadPoolDevice* intra_op_thread_pool_ = nullptr;
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index acaf9eaafb..9a041e3f41 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -332,6 +332,7 @@ cc_library(
hdrs = ["backend.h"],
deps = [
":compiler",
+ ":computation_placer",
":device_memory_allocator",
":platform_util",
":pool",
@@ -951,6 +952,26 @@ cc_test(
)
cc_library(
+ name = "computation_placer",
+ srcs = ["computation_placer.cc"],
+ hdrs = ["computation_placer.h"],
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:stream_executor_no_cuda",
+ ],
+ alwayslink = True, # Contains per-platform computation placer registration
+)
+
+cc_library(
name = "generic_transfer_manager",
srcs = ["generic_transfer_manager.cc"],
hdrs = ["generic_transfer_manager.h"],
diff --git a/tensorflow/compiler/xla/service/backend.cc b/tensorflow/compiler/xla/service/backend.cc
index 66d54ad380..83f954789c 100644
--- a/tensorflow/compiler/xla/service/backend.cc
+++ b/tensorflow/compiler/xla/service/backend.cc
@@ -96,9 +96,11 @@ struct Backend::EigenThreadPoolWrapper {
PlatformUtil::GetStreamExecutors(platform));
TF_ASSIGN_OR_RETURN(auto transfer_manager,
TransferManager::GetForPlatform(platform));
- std::unique_ptr<Backend> backend(
- new Backend(replica_count, platform, compiler, stream_executors,
- transfer_manager, options.intra_op_parallelism_threads()));
+ TF_ASSIGN_OR_RETURN(auto computation_placer,
+ ComputationPlacer::GetForPlatform(platform));
+ std::unique_ptr<Backend> backend(new Backend(
+ replica_count, platform, compiler, stream_executors, transfer_manager,
+ computation_placer, options.intra_op_parallelism_threads()));
return std::move(backend);
}
@@ -135,10 +137,12 @@ Backend::Backend(
int64 replica_count, perftools::gputools::Platform* platform,
Compiler* compiler,
tensorflow::gtl::ArraySlice<se::StreamExecutor*> stream_executors,
- TransferManager* transfer_manager, int intra_op_parallelism_threads)
+ TransferManager* transfer_manager, ComputationPlacer* computation_placer,
+ int intra_op_parallelism_threads)
: platform_(platform),
compiler_(compiler),
transfer_manager_(transfer_manager),
+ computation_placer_(computation_placer),
replica_count_(replica_count) {
// The given set of stream executors set may include invalid executors.
for (se::StreamExecutor* exec : stream_executors) {
@@ -179,36 +183,6 @@ int Backend::default_device_ordinal() const {
return default_stream_executor()->device_ordinal();
}
-StatusOr<std::vector<perftools::gputools::StreamExecutor*>> Backend::Replicas(
- int device_ordinal) const {
- if (stream_executors_[device_ordinal] == nullptr) {
- return InvalidArgument("device %s not supported by XLA service",
- device_name(device_ordinal).c_str());
- }
-
- // Find replica_count_ stream executors starting from the given device
- // ordinal.
- std::vector<perftools::gputools::StreamExecutor*> replicas;
- for (se::StreamExecutor* exec : stream_executors_) {
- CHECK(exec != nullptr);
- if (exec->device_ordinal() >= device_ordinal) {
- replicas.push_back(exec);
- if (replicas.size() >= replica_count_) {
- return replicas;
- }
- }
- }
-
- return InvalidArgument(
- "Not enough devices for replicas for the device ordinal %d",
- device_ordinal);
-}
-
-std::vector<perftools::gputools::StreamExecutor*> Backend::Replicas() const {
- CHECK_GE(stream_executors_.size(), replica_count_);
- return Replicas(default_device_ordinal()).ValueOrDie();
-}
-
tensorflow::thread::ThreadPool* Backend::inter_op_thread_pool() const {
return inter_op_thread_pool_.get();
}
diff --git a/tensorflow/compiler/xla/service/backend.h b/tensorflow/compiler/xla/service/backend.h
index e0b15dc43f..3ead274cde 100644
--- a/tensorflow/compiler/xla/service/backend.h
+++ b/tensorflow/compiler/xla/service/backend.h
@@ -22,6 +22,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/service/compiler.h"
+#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/pool.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
@@ -40,6 +41,8 @@ struct ThreadPoolDevice;
namespace xla {
// Options to configure the backend when it is created.
+//
+// TODO(b/62588571): Remove the notion of replicas from the backend.
class BackendOptions {
public:
// Set the platform backing the backend, or nullptr for the default platform.
@@ -92,11 +95,15 @@ class Backend {
return memory_allocator_.get();
}
TransferManager* transfer_manager() const { return transfer_manager_; }
+ ComputationPlacer* computation_placer() const { return computation_placer_; }
// Returns the number of devices of the platform type which are visible. Not
// all of these devices may be usable by XLA.
int device_count() const { return stream_executors_.size(); }
+ // Returns the number of replicas when replication is enabled.
+ int replica_count() const { return replica_count_; }
+
// Returns the device ordinal number of the default device.
int default_device_ordinal() const;
@@ -107,24 +114,13 @@ class Backend {
return stream_executors_;
}
- // Returns the replicas for the default stream executor.
- //
- // When the number of replicas is R, the first R stream executors are assigned
- // to the replicas of the default stream executor.
- std::vector<perftools::gputools::StreamExecutor*> Replicas() const;
-
- // Returns the replicas for the given device_ordinal. The given device ordinal
- // is considered to be the first device ordinal among the replicas. Returns an
- // error status if the stream executor for the given given device ordinal does
- // not exist or if there are not enough stream executors for the replicas.
- StatusOr<std::vector<perftools::gputools::StreamExecutor*>> Replicas(
- int device_ordinal) const;
-
- // Return the stream executor for the given device ordinal.
+ // Returns the stream executor for the given device ordinal.
StatusOr<perftools::gputools::StreamExecutor*> stream_executor(
int device_ordinal) const;
- // Return the stream executor for the default device ordinal.
+ // Returns the stream executor for the default device ordinal. This stream
+ // executor can only be used when the number of computations is 1 (the
+ // replica count may still be greater than 1).
perftools::gputools::StreamExecutor* default_stream_executor() const {
CHECK(!stream_executors_.empty());
return stream_executors_[0];
@@ -178,13 +174,16 @@ class Backend {
Compiler* compiler,
tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
stream_executors,
- TransferManager* transfer_manager, int intra_op_parallelism_threads);
+ TransferManager* transfer_manager,
+ ComputationPlacer* computation_placer,
+ int intra_op_parallelism_threads);
Backend(const Backend&) = delete;
Backend& operator=(const Backend&) = delete;
perftools::gputools::Platform* platform_;
Compiler* compiler_;
TransferManager* transfer_manager_;
+ ComputationPlacer* computation_placer_;
int64 replica_count_ = -1;
// Vector of stream executors. stream_executors_[0] is the default executor.
diff --git a/tensorflow/compiler/xla/service/computation_placer.cc b/tensorflow/compiler/xla/service/computation_placer.cc
new file mode 100644
index 0000000000..cdf277581f
--- /dev/null
+++ b/tensorflow/compiler/xla/service/computation_placer.cc
@@ -0,0 +1,151 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/computation_placer.h"
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace se = ::perftools::gputools;
+
+namespace xla {
+
+Status DeviceAssignment::Serialize(DeviceAssignmentProto* proto) const {
+ proto->set_replica_count(replica_count());
+ proto->set_computation_count(computation_count());
+ for (int computation = 0; computation < computation_count(); ++computation) {
+ DeviceAssignmentProto::ComputationDevice* computation_device =
+ proto->add_computation_devices();
+ for (int replica = 0; replica < replica_count(); ++replica) {
+ computation_device->add_replica_device_ids((*this)(replica, computation));
+ }
+ }
+ return Status::OK();
+}
+
+/* static */ StatusOr<DeviceAssignment> DeviceAssignment::Deserialize(
+ const DeviceAssignmentProto& proto) {
+ TF_RET_CHECK(proto.computation_devices_size() == proto.computation_count());
+ DeviceAssignment assignment(proto.replica_count(), proto.computation_count());
+ for (int computation = 0; computation < proto.computation_count();
+ ++computation) {
+ const auto& computation_device = proto.computation_devices(computation);
+ TF_RET_CHECK(computation_device.replica_device_ids_size() ==
+ proto.replica_count());
+ for (int replica = 0; replica < proto.replica_count(); ++replica) {
+ assignment(replica, computation) =
+ computation_device.replica_device_ids(replica);
+ }
+ }
+ return std::move(assignment);
+}
+
+StatusOr<int> ComputationPlacer::DeviceId(int replica, int computation,
+ int replica_count,
+ int computation_count) {
+ TF_RET_CHECK(replica < replica_count);
+ TF_RET_CHECK(computation < computation_count);
+
+ return computation * replica_count + replica;
+}
+
+StatusOr<DeviceAssignment> ComputationPlacer::AssignDevices(
+ int replica_count, int computation_count) {
+ DeviceAssignment assignment(replica_count, computation_count);
+ for (int replica = 0; replica < replica_count; ++replica) {
+ for (int computation = 0; computation < computation_count; ++computation) {
+ TF_ASSIGN_OR_RETURN(
+ int device_id,
+ DeviceId(replica, computation, replica_count, computation_count));
+ assignment(replica, computation) = device_id;
+ }
+ }
+ return std::move(assignment);
+}
+
+/* static */ void ComputationPlacer::RegisterComputationPlacer(
+ se::Platform::Id platform_id,
+ ComputationPlacerCreationFunction creation_function) {
+ tensorflow::mutex_lock lock(
+ *ComputationPlacer::platform_computation_placer_mutex());
+ auto* computation_placers = GetPlatformComputationPlacers();
+ CHECK(computation_placers->find(platform_id) == computation_placers->end());
+ (*computation_placers)[platform_id].creation_function = creation_function;
+}
+
+/* static */ StatusOr<ComputationPlacer*> ComputationPlacer::GetForPlatform(
+ const se::Platform* platform) {
+ tensorflow::mutex_lock lock(
+ *ComputationPlacer::platform_computation_placer_mutex());
+ auto* computation_placers = GetPlatformComputationPlacers();
+
+ auto it = computation_placers->find(platform->id());
+ if (it == computation_placers->end()) {
+ return NotFound(
+ "could not find registered computation placer for platform %s -- check "
+ "target linkage",
+ platform->Name().c_str());
+ }
+
+ if (it->second.placer == nullptr) {
+ // Lazily create the computation placer the first time it is needed.
+ it->second.placer = (*it->second.creation_function)();
+ }
+
+ return it->second.placer.get();
+}
+
+/* static */ tensorflow::mutex*
+ComputationPlacer::platform_computation_placer_mutex() {
+ static tensorflow::mutex* m = new tensorflow::mutex;
+ return m;
+}
+
+/* static */ std::map<perftools::gputools::Platform::Id,
+ ComputationPlacer::State>*
+ComputationPlacer::GetPlatformComputationPlacers() {
+ static auto* r =
+ new std::map<perftools::gputools::Platform::Id, ComputationPlacer::State>;
+ return r;
+}
+
+} // namespace xla
+
+static std::unique_ptr<xla::ComputationPlacer> CreateComputationPlacer() {
+ return xla::MakeUnique<xla::ComputationPlacer>();
+}
+
+static bool InitModule() {
+ xla::ComputationPlacer::RegisterComputationPlacer(se::host::kHostPlatformId,
+ &CreateComputationPlacer);
+ xla::ComputationPlacer::RegisterComputationPlacer(se::cuda::kCudaPlatformId,
+ &CreateComputationPlacer);
+ return true;
+}
+static bool module_initialized = InitModule();
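
For intuition, the default placer lays device ids out computation-major, i.e. device_id = computation * replica_count + replica. A hedged, illustrative check of that layout (the function below is not part of this commit):

// Illustrative only: with replica_count = 2 and computation_count = 3 the
// default assignment is
//   replica 0: 0 2 4
//   replica 1: 1 3 5
xla::Status CheckDefaultPlacement() {
  xla::ComputationPlacer placer;
  TF_ASSIGN_OR_RETURN(xla::DeviceAssignment assignment,
                      placer.AssignDevices(/*replica_count=*/2,
                                           /*computation_count=*/3));
  TF_RET_CHECK(assignment(/*replica=*/1, /*computation=*/2) == 5);
  return xla::Status::OK();
}
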
diff --git a/tensorflow/compiler/xla/service/computation_placer.h b/tensorflow/compiler/xla/service/computation_placer.h
new file mode 100644
index 0000000000..4d26d6bb85
--- /dev/null
+++ b/tensorflow/compiler/xla/service/computation_placer.h
@@ -0,0 +1,113 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_PLACER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_PLACER_H_
+
+#include <map>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/status.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+// Class that represents the device assignment for a set of XLA replicated
+// computations. For R replicas and C computations, R * C devices are required
+// to execute the computation in parallel. The assigned device ids can be
+// accessed by assignment(replica, computation).
+class DeviceAssignment : public Array2D<int> {
+ public:
+ DeviceAssignment() {}
+ DeviceAssignment(int replica_count, int computation_count)
+ : Array2D<int>(replica_count, computation_count, -1) {
+ CHECK_GT(replica_count, 0);
+ CHECK_GT(computation_count, 0);
+ }
+
+ int replica_count() const { return height(); }
+ int computation_count() const { return width(); }
+
+ // Protocol buffer serialization and deserialization.
+ Status Serialize(DeviceAssignmentProto* proto) const;
+ static StatusOr<DeviceAssignment> Deserialize(
+ const DeviceAssignmentProto& proto);
+};
+
+// A generic implementation of the XLA computation placer, which assigns device
+// ids to a set of replicated computations.
+class ComputationPlacer {
+ public:
+ ComputationPlacer() {}
+ virtual ~ComputationPlacer() {}
+
+ // Returns the device id assigned to the given replica and computation
+ // instance for a [replica_count x computation_count] setup. The returned
+ // device id must match the assignment produced by AssignDevices().
+ virtual StatusOr<int> DeviceId(int replica, int computation,
+ int replica_count, int computation_count);
+
+ // Returns the device ids assigned to a set of replicated computations, given
+ // the number of replicas and the number of computations.
+ virtual StatusOr<DeviceAssignment> AssignDevices(int replica_count,
+ int computation_count);
+
+ using ComputationPlacerCreationFunction =
+ std::unique_ptr<ComputationPlacer> (*)();
+
+ // Registers a computation placer creation function for a particular platform.
+ static void RegisterComputationPlacer(
+ perftools::gputools::Platform::Id platform_id,
+ ComputationPlacerCreationFunction creation_function);
+
+ // Returns the computation placer singleton pointer if it is available for the
+ // given platform, or an error status if it is not.
+ static StatusOr<ComputationPlacer*> GetForPlatform(
+ const perftools::gputools::Platform* platform);
+
+ private:
+ // Routine that returns the mutex that guards the platform-to-computation
+ // placer map. Done as a routine to ensure correct initialization ordering,
+ // since RegisterComputationPlacer can be called during program initialization
+ // time.
+ static tensorflow::mutex* platform_computation_placer_mutex();
+
+ // State kept for each kind of ComputationPlacer. Registration functions set
+ // up creation_function, and then we use that to lazily create "placer" the
+ // first time GetForPlatform is invoked for a particular id.
+ struct State {
+ std::unique_ptr<ComputationPlacer> placer;
+ ComputationPlacerCreationFunction creation_function = nullptr;
+ };
+
+ // Map from platform kind to computation placer singleton.
+ static std::map<perftools::gputools::Platform::Id, State>*
+ GetPlatformComputationPlacers();
+
+ perftools::gputools::Platform::Id platform_id_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ComputationPlacer);
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_PLACER_H_
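
DeviceAssignment also gains proto serialization; a minimal round-trip sketch, assuming only the API declared above (illustrative, not part of the commit):

// Build a 2-replica x 2-computation assignment by hand, serialize it, and
// deserialize it back.
xla::DeviceAssignment original(/*replica_count=*/2, /*computation_count=*/2);
original(0, 0) = 0;  // device for (replica 0, computation 0)
original(1, 0) = 1;
original(0, 1) = 2;
original(1, 1) = 3;

xla::DeviceAssignmentProto proto;
TF_CHECK_OK(original.Serialize(&proto));
xla::DeviceAssignment restored =
    xla::DeviceAssignment::Deserialize(proto).ValueOrDie();
CHECK_EQ(restored.replica_count(), 2);
CHECK_EQ(restored(1, 1), 3);
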
diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc
index 131c2ee87b..748124ebc7 100644
--- a/tensorflow/compiler/xla/service/local_service.cc
+++ b/tensorflow/compiler/xla/service/local_service.cc
@@ -152,7 +152,7 @@ StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
// Construct computation layout from the argument layouts.
auto module_config = MakeUnique<HloModuleConfig>(*program_shape);
module_config->set_has_hybrid_result(has_hybrid_result);
- module_config->set_replica_count(execute_backend_->Replicas().size());
+ module_config->set_replica_count(execute_backend_->replica_count());
legacy_flags::ServiceFlags* flags = legacy_flags::GetServiceFlags();
if (flags->xla_hlo_profile) {
module_config->enable_hlo_profiling(true);
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index 79bf679c5f..5812d3e487 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -325,7 +325,7 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
module_config->enable_hlo_profiling(true);
}
- module_config->set_replica_count(backend->Replicas().size());
+ module_config->set_replica_count(backend->replica_count());
module_config->set_seed(execution_options.seed());
module_config->set_debug_options(execution_options.debug_options());
@@ -495,47 +495,55 @@ Service::ExecuteParallelAndRegisterResult(
tensorflow::gtl::ArraySlice<
std::vector<perftools::gputools::DeviceMemoryBase>>
arguments,
- Backend* backend,
- tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*> executors,
+ Backend* backend, tensorflow::gtl::ArraySlice<DeviceHandle> device_handles,
tensorflow::gtl::ArraySlice<string> result_tags) {
- // TODO(b/33943292): Support for replication when using multiple computations.
- TF_RET_CHECK(backend->Replicas().size() == 1);
-
- // Set up streams.
+ // Streams on which the computations are launched, so that we can wait for
+ // the streams to complete.
std::vector<Pool<se::Stream>::SmartPtr> streams;
- for (se::StreamExecutor* executor : executors) {
- TF_ASSIGN_OR_RETURN(Pool<se::Stream>::SmartPtr stream,
- backend->BorrowStream(executor));
- streams.push_back(std::move(stream));
- }
-
- // Set up run options.
- std::vector<ServiceExecutableRunOptions> run_options;
- for (const Pool<se::Stream>::SmartPtr& stream : streams) {
- ExecutableRunOptions options;
- options.set_stream(stream.get());
- options.set_allocator(backend->memory_allocator());
- options.set_inter_op_thread_pool(backend->inter_op_thread_pool());
- options.set_intra_op_thread_pool(
- backend->eigen_intra_op_thread_pool_device());
- run_options.emplace_back(options, backend->StreamBorrower());
- }
-
- // Asynchronously launch all executables.
+ // Global data handles for the computation results, one for each computation.
std::vector<GlobalDataHandle> result_handles;
- for (tensorflow::gtl::ArraySlice<Executable*>::size_type i = 0;
- i < executables.size(); i++) {
- TF_ASSIGN_OR_RETURN(
- perftools::gputools::DeviceMemoryBase result,
- executables[i]->ExecuteAsyncOnStream(&run_options[i], arguments[i]));
- result_handles.push_back(allocation_tracker_.Register(
- backend, executors[i]->device_ordinal(), result,
- executables[i]->result_shape(), result_tags[i]));
+
+ TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
+ backend->computation_placer()->AssignDevices(
+ backend->replica_count(), executables.size()));
+
+ for (int64 i = 0; i < executables.size(); i++) {
+ // Stream executors for the replicas of the current computation.
+ TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handles[i]));
+ for (int64 replica = 0; replica < replicas.size(); ++replica) {
+ TF_ASSIGN_OR_RETURN(Pool<se::Stream>::SmartPtr stream,
+ backend->BorrowStream(replicas[replica]));
+ streams.push_back(std::move(stream));
+
+ // Set up run options.
+ ExecutableRunOptions options;
+ options.set_stream(streams.back().get());
+ options.set_allocator(backend->memory_allocator());
+ options.set_inter_op_thread_pool(backend->inter_op_thread_pool());
+ options.set_intra_op_thread_pool(
+ backend->eigen_intra_op_thread_pool_device());
+ options.set_device_assignment(&device_assignment);
+ ServiceExecutableRunOptions run_options(options,
+ backend->StreamBorrower());
+
+ // Asynchronously launch the computation.
+ TF_ASSIGN_OR_RETURN(
+ perftools::gputools::DeviceMemoryBase result,
+ executables[i]->ExecuteAsyncOnStream(&run_options, arguments[i]));
+
+ // All replicas share the same device address for the result allocation,
+ // so only one of the replicas needs to register the result handle.
+ if (replica == 0) {
+ result_handles.push_back(allocation_tracker_.Register(
+ backend, replicas[0]->device_ordinal(), result,
+ executables[i]->result_shape(), result_tags[i]));
+ }
+ }
}
// Wait for all executions to complete.
- for (int64 i = 0; i < result_handles.size(); ++i) {
+ for (int64 i = 0; i < streams.size(); ++i) {
if (!streams[i]->BlockHostUntilDone()) {
return InternalError("failed to complete execution for stream %lld", i);
}
@@ -550,17 +558,22 @@ StatusOr<GlobalDataHandle> Service::ExecuteAndRegisterResult(
arguments,
Backend* backend, perftools::gputools::StreamExecutor* executor,
const string& result_tag, ExecutionProfile* profile) {
- TF_RET_CHECK(!backend->Replicas().empty());
-
// Set up streams.
std::vector<Pool<se::Stream>::SmartPtr> streams;
- for (se::StreamExecutor* executor : backend->Replicas()) {
+ TF_ASSIGN_OR_RETURN(auto replicas,
+ Replicas(*backend, SingleComputationDeviceHandle()));
+ TF_RET_CHECK(!replicas.empty());
+ for (se::StreamExecutor* executor : replicas) {
TF_ASSIGN_OR_RETURN(Pool<se::Stream>::SmartPtr stream,
backend->BorrowStream(executor));
streams.push_back(std::move(stream));
}
+ TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
+ backend->computation_placer()->AssignDevices(
+ backend->replica_count(), /*computation_count=*/1));
+
// Set up run options.
std::vector<ServiceExecutableRunOptions> run_options;
for (const Pool<se::Stream>::SmartPtr& stream : streams) {
@@ -570,19 +583,20 @@ StatusOr<GlobalDataHandle> Service::ExecuteAndRegisterResult(
options.set_inter_op_thread_pool(backend->inter_op_thread_pool());
options.set_intra_op_thread_pool(
backend->eigen_intra_op_thread_pool_device());
+ options.set_device_assignment(&device_assignment);
run_options.emplace_back(options, backend->StreamBorrower(),
backend->inter_op_thread_pool());
}
perftools::gputools::DeviceMemoryBase result;
- if (backend->Replicas().size() == 1) {
+ if (backend->replica_count() == 1) {
TF_ASSIGN_OR_RETURN(
result, executable->ExecuteOnStreamWrapper<se::DeviceMemoryBase>(
&run_options[0], profile, arguments));
} else {
std::vector<
tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>>
- repeated_arguments(backend->Replicas().size(), arguments);
+ repeated_arguments(backend->replica_count(), arguments);
TF_ASSIGN_OR_RETURN(auto results, executable->ExecuteOnStreams(
run_options, repeated_arguments));
@@ -610,25 +624,26 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
std::vector<VersionedComputationHandle> versioned_handles;
std::vector<std::unique_ptr<HloModuleConfig>> module_configs;
std::vector<string> computation_names;
+ std::vector<DeviceHandle> device_handles;
- if (arg->requests_size() > execute_backend_->stream_executors().size()) {
+ if (arg->requests_size() * execute_backend_->replica_count() >
+ execute_backend_->device_count()) {
return FailedPrecondition(
"there are not enough stream executors to execute %d computations",
arg->requests_size());
}
for (int64 i = 0; i < arg->requests_size(); ++i) {
- // Get the stream executor on which the computation will run. Select the
- // specific device if requested, otherwise select the i'th device from the
- // list of available stream executors.
- se::StreamExecutor* executor;
- if (arg->requests(i).has_device_handle()) {
- executor =
- execute_backend_
- ->stream_executors()[arg->requests(i).device_handle().handle()];
- } else {
- executor = execute_backend_->stream_executors()[i];
+ // Get the stream executor for the i'th computation. This stream executor
+ // is one of the executors that will run the replicated computation.
+ if (!arg->requests(i).has_device_handle()) {
+ return FailedPrecondition(
+ "device handles must be given to execute parallel computations");
}
+ TF_ASSIGN_OR_RETURN(
+ auto replicas,
+ Replicas(*execute_backend_, arg->requests(i).device_handle()));
+ se::StreamExecutor* executor = replicas[0];
CHECK(executor != nullptr);
// Resolve the UserComputation object associated with the requested
@@ -673,6 +688,7 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
module_configs.push_back(std::move(module_config));
computation_names.push_back(user_computation->name());
executors.push_back(executor);
+ device_handles.push_back(arg->requests(i).device_handle());
}
// Build the user computations into HloModules and compile to generate the
@@ -692,7 +708,7 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
TF_ASSIGN_OR_RETURN(
std::vector<GlobalDataHandle> outputs,
ExecuteParallelAndRegisterResult(executable_ptrs, all_arguments,
- execute_backend_.get(), executors,
+ execute_backend_.get(), device_handles,
computation_names));
for (const GlobalDataHandle& output : outputs) {
ExecuteResponse response;
@@ -706,10 +722,12 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
tensorflow::Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg,
GetDeviceHandlesResponse* result) {
- const int64 available_device_count =
- execute_backend_->stream_executors().size();
- const int64 replicas = execute_backend_->Replicas().size();
- if (available_device_count < arg->device_count() * replicas) {
+ const int64 available_device_count = execute_backend_->device_count();
+ const int64 replica_count = execute_backend_->replica_count();
+ if (replica_count <= 0) {
+ return FailedPrecondition("Replica count must be a positive integer");
+ }
+ if (available_device_count < arg->device_count() * replica_count) {
return ResourceExhausted(
"Requested device count (%lld) exceeds the number of available devices "
"on the target (%lld)",
@@ -718,8 +736,8 @@ tensorflow::Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg,
for (int64 i = 0; i < arg->device_count(); ++i) {
DeviceHandle device_handle;
- device_handle.set_handle(
- execute_backend_->stream_executors()[i * replicas]->device_ordinal());
+ device_handle.set_handle(i);
+ device_handle.set_device_count(arg->device_count());
*result->add_device_handles() = device_handle;
}
@@ -841,11 +859,14 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg,
execute_backend_->default_stream_executor(),
&profile));
- TF_RET_CHECK(!execute_backend_->Replicas().empty());
+ TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_,
+ SingleComputationDeviceHandle()));
+ TF_RET_CHECK(!replicas.empty());
+
// Set up streams.
std::vector<Pool<se::Stream>::SmartPtr> streams;
- for (se::StreamExecutor* executor : execute_backend_->Replicas()) {
+ for (se::StreamExecutor* executor : replicas) {
TF_ASSIGN_OR_RETURN(Pool<se::Stream>::SmartPtr stream,
execute_backend_->BorrowStream(executor));
streams.push_back(std::move(stream));
@@ -927,19 +948,20 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg,
Literal literal = Literal(arg->literal());
const Shape& shape = literal.shape();
- if (ShapeUtil::IsTuple(shape) && execute_backend_->Replicas().size() > 1) {
+ if (ShapeUtil::IsTuple(shape) && execute_backend_->replica_count() > 1) {
// TODO(b/32990684): Tuple transfers to host end up allocating further
// buffers - implement that correctly.
return Unimplemented(
"Tuple transfers to the device not supported with replication.");
}
- se::StreamExecutor* stream_executor;
+ std::vector<se::StreamExecutor*> replicas;
if (arg->has_device_handle()) {
- TF_ASSIGN_OR_RETURN(stream_executor, execute_backend_->stream_executor(
- arg->device_handle().handle()));
+ TF_ASSIGN_OR_RETURN(replicas,
+ Replicas(*execute_backend_, arg->device_handle()));
} else {
- stream_executor = execute_backend_->default_stream_executor();
+ TF_ASSIGN_OR_RETURN(
+ replicas, Replicas(*execute_backend_, SingleComputationDeviceHandle()));
}
// Allocate memory on the device, using the stream executor. The size of the
@@ -950,14 +972,12 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg,
TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase allocation,
execute_backend_->memory_allocator()->Allocate(
- stream_executor->device_ordinal(), allocation_size));
+ replicas[0]->device_ordinal(), allocation_size));
*result->mutable_data() = allocation_tracker_.Register(
- execute_backend_.get(), stream_executor->device_ordinal(), allocation,
- shape, StrCat("TransferToServer literal of size ", allocation_size));
+ execute_backend_.get(), replicas[0]->device_ordinal(), allocation, shape,
+ StrCat("TransferToServer literal of size ", allocation_size));
- TF_ASSIGN_OR_RETURN(auto replicas, execute_backend_->Replicas(
- stream_executor->device_ordinal()));
for (se::StreamExecutor* executor : replicas) {
TF_RETURN_IF_ERROR(
execute_backend_->transfer_manager()->TransferLiteralToDevice(
@@ -968,7 +988,7 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg,
tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
TransferToInfeedResponse* result) {
- const int64 replica_count = execute_backend_->Replicas().size();
+ const int64 replica_count = execute_backend_->replica_count();
if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) {
return FailedPrecondition(
"%s",
@@ -980,11 +1000,14 @@ tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
se::StreamExecutor* executor;
if (arg->has_device_handle()) {
- TF_ASSIGN_OR_RETURN(auto replicas, execute_backend_->Replicas(
- arg->device_handle().handle()));
+ TF_ASSIGN_OR_RETURN(auto replicas,
+ Replicas(*execute_backend_, arg->device_handle()));
executor = replicas[arg->replica_id()];
} else {
- executor = execute_backend_->Replicas()[arg->replica_id()];
+ TF_ASSIGN_OR_RETURN(
+ auto replicas,
+ Replicas(*execute_backend_, SingleComputationDeviceHandle()));
+ executor = replicas[arg->replica_id()];
}
return execute_backend_->transfer_manager()->TransferLiteralToInfeed(
@@ -994,7 +1017,7 @@ tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
tensorflow::Status Service::TransferFromOutfeed(
const TransferFromOutfeedRequest* arg,
TransferFromOutfeedResponse* result) {
- const int64 replica_count = execute_backend_->Replicas().size();
+ const int64 replica_count = execute_backend_->replica_count();
if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) {
return FailedPrecondition(
"The replica_id=%lld on TransferFromOutfeedRequest not in range [0, "
@@ -1004,11 +1027,14 @@ tensorflow::Status Service::TransferFromOutfeed(
se::StreamExecutor* executor;
if (arg->has_device_handle()) {
- TF_ASSIGN_OR_RETURN(auto replicas, execute_backend_->Replicas(
- arg->device_handle().handle()));
+ TF_ASSIGN_OR_RETURN(auto replicas,
+ Replicas(*execute_backend_, arg->device_handle()));
executor = replicas[arg->replica_id()];
} else {
- executor = execute_backend_->Replicas()[arg->replica_id()];
+ TF_ASSIGN_OR_RETURN(
+ auto replicas,
+ Replicas(*execute_backend_, SingleComputationDeviceHandle()));
+ executor = replicas[arg->replica_id()];
}
Literal literal;
@@ -1387,4 +1413,28 @@ tensorflow::Status Service::LoadComputationSnapshot(
return tensorflow::Status::OK();
}
+DeviceHandle Service::SingleComputationDeviceHandle() const {
+ DeviceHandle device_handle;
+ device_handle.set_handle(0);
+ device_handle.set_device_count(1);
+ return device_handle;
+}
+
+StatusOr<std::vector<perftools::gputools::StreamExecutor*>> Service::Replicas(
+ const Backend& backend, const DeviceHandle& device_handle) const {
+ std::vector<perftools::gputools::StreamExecutor*> replicas;
+ for (int replica = 0; replica < backend.replica_count(); ++replica) {
+ // From the computation placer, find out the device ids of the replicas for
+ // the given device handle.
+ TF_ASSIGN_OR_RETURN(
+ int device_ordinal,
+ backend.computation_placer()->DeviceId(replica, device_handle.handle(),
+ backend.replica_count(),
+ device_handle.device_count()));
+ TF_ASSIGN_OR_RETURN(auto executor, backend.stream_executor(device_ordinal));
+ replicas.push_back(executor);
+ }
+ return replicas;
+}
+
} // namespace xla
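
The new Service::Replicas helper maps a logical DeviceHandle to the stream executors of its replicas via ComputationPlacer::DeviceId. A hedged illustration of that arithmetic with the default placer (not part of the commit):

// With replica_count = 2 and device_count = 3 (three model-parallel
// computations), DeviceHandle{handle = 1} resolves to device ordinals
// DeviceId(replica, 1, 2, 3) = 1 * 2 + replica, i.e. ordinals 2 and 3.
xla::ComputationPlacer placer;
xla::DeviceHandle handle;
handle.set_handle(1);
handle.set_device_count(3);
for (int replica = 0; replica < 2; ++replica) {
  int ordinal = placer
                    .DeviceId(replica, handle.handle(),
                              /*replica_count=*/2, handle.device_count())
                    .ValueOrDie();
  LOG(INFO) << "replica " << replica << " -> device ordinal " << ordinal;
}
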
diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h
index abd1281bdd..81aa0d2e8e 100644
--- a/tensorflow/compiler/xla/service/service.h
+++ b/tensorflow/compiler/xla/service/service.h
@@ -319,8 +319,7 @@ class Service : public ServiceInterface {
std::vector<perftools::gputools::DeviceMemoryBase>>
arguments,
Backend* backend,
- tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
- executors,
+ tensorflow::gtl::ArraySlice<DeviceHandle> device_handles,
tensorflow::gtl::ArraySlice<string> result_tags);
// Returns an HLO dumper for use in the compiler (it refers to flags
@@ -346,6 +345,16 @@ class Service : public ServiceInterface {
tensorflow::Status ValidateResultShapeWithLayout(
const Shape& shape_with_layout, const Shape& result_shape) const;
+ // Returns the stream executors assigned to the replicas represented by the
+ // given device handle. Each device_handle is a virtual replicated device that
+ // represents a set of physical devices for the replicas.
+ StatusOr<std::vector<perftools::gputools::StreamExecutor*>> Replicas(
+ const Backend& backend, const DeviceHandle& device_handle) const;
+
+ // Returns the device handle that represents the replicated device for a
+ // single computation that is not model-parallelized.
+ DeviceHandle SingleComputationDeviceHandle() const;
+
// Tracks computations built via the API.
ComputationTracker computation_tracker_;
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 5fa515b26f..b81116c100 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -99,6 +99,7 @@ cc_library(
"//tensorflow/compiler/xla/service:backend",
"//tensorflow/compiler/xla/service:compiler",
"//tensorflow/compiler/xla/service:computation_layout",
+ "//tensorflow/compiler/xla/service:computation_placer",
"//tensorflow/compiler/xla/service:executable",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_execution_profile",
@@ -196,6 +197,7 @@ cc_library(
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:computation",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/service:computation_placer",
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/compiler/xla/service:local_service",
"//tensorflow/compiler/xla/service:platform_util",
@@ -746,6 +748,7 @@ xla_test(
"//tensorflow/compiler/xla/client:computation",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/service:computation_placer",
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/compiler/xla/service:local_service",
"//tensorflow/compiler/xla/service:platform_util",
diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto
index 633d16c4c3..c8fd31d0ad 100644
--- a/tensorflow/compiler/xla/xla_data.proto
+++ b/tensorflow/compiler/xla/xla_data.proto
@@ -255,11 +255,15 @@ message ComputationDataHandle {
int64 handle = 1;
}
-// Handle given to a user that represents a device to execute a computation.
-// When replication is enabled, the device handle represents the device for the
-// replica id 0.
+// Handle given to a user that represents a replicated virtual device. Each
+// replicated device represents N physical devices for execution where N is the
+// number of replicas.
message DeviceHandle {
int64 handle = 1;
+
+ // The number of model-parallel virtual devices that communicate via XLA
+ // Send/Recv instructions.
+ int64 device_count = 2;
}
// Handle given to a user to represent a channel between two computations
@@ -269,6 +273,21 @@ message ChannelHandle {
int64 handle = 1;
}
+// DeviceAssignmentProto is a serialized form of the DeviceAssignment class,
+// which represents the device ids assigned to a set of replicated
+// computations. See the xla::DeviceAssignment class comment for more details.
+message DeviceAssignmentProto {
+ int32 replica_count = 1;
+ int32 computation_count = 2;
+
+ // Each logical computation runs on replica_count physical devices.
+ // ComputationDevice represents the device ids assigned to the replicas.
+ message ComputationDevice {
+ repeated int32 replica_device_ids = 1;
+ }
+ repeated ComputationDevice computation_devices = 3;
+}
+
// Literals are used when the server and client need to exchange materialized
// data / results. Literals are also used to describe constants used in
// computations.
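
For reference, a hedged sketch of what the new DeviceAssignmentProto holds for the default 2-replica x 3-computation placement, built with the generated C++ accessors that Serialize() uses above (illustrative only):

// Illustrative only: the proto equivalent of the default assignment where
// device_id = computation * replica_count + replica.
xla::DeviceAssignmentProto proto;
proto.set_replica_count(2);
proto.set_computation_count(3);
for (int computation = 0; computation < 3; ++computation) {
  auto* devices = proto.add_computation_devices();
  for (int replica = 0; replica < 2; ++replica) {
    devices->add_replica_device_ids(computation * 2 + replica);
  }
}
// Resulting replica_device_ids per computation: {0, 1}, {2, 3}, {4, 5}.
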