about · summary · refs · log · tree · commit · diff · homepage
path: root/tensorflow/core/distributed_runtime
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2018-06-27 12:52:51 -0700
committer: Gunhan Gulsoy <gunan@google.com> 2018-06-28 21:37:43 -0700
commit f5f67296c3430bae595697af3a78460e027cdc6d (patch)
tree 06c4d7addb7820150d8620e79d41701f5db7b073 /tensorflow/core/distributed_runtime
parent 15f6b62aeef8292eddd6edbc9ed15cc49774218e (diff)
Add GPUOptions::num_dev_to_dev_copy_streams to allow creation of
more than one device-to-device copy stream per GPU device. This is an experimental feature that will have no effect unless copy operations explicitly request a stream other than 0, which currently does not occur anywhere in a standard build. Eventually it may be of benefit in the presence of multiple bi-directional concurrent data copies.

PiperOrigin-RevId: 202354513
Diffstat (limited to 'tensorflow/core/distributed_runtime')
-rw-r--r--  tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc             |  2
-rw-r--r--  tensorflow/core/distributed_runtime/collective_rma_distributed.cc      | 11
-rw-r--r--  tensorflow/core/distributed_runtime/collective_rma_distributed.h       |  1
-rw-r--r--  tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc |  6
4 files changed, 12 insertions, 8 deletions
diff --git a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc
index 5f6931e008..de6e4b4a7c 100644
--- a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc
+++ b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc
@@ -281,7 +281,7 @@ void BaseRemoteRendezvous::SameWorkerRecvDone(
CopyTensor::ViaDMA(parsed.edge_name, send_args.device_context,
recv_args.device_context, src_device, dst_device,
send_args.alloc_attrs, recv_args.alloc_attrs, &in, out,
- std::move(done));
+ 0 /*dev_to_dev_stream_index*/, std::move(done));
}
bool BaseRemoteRendezvous::IsSameWorker(DeviceNameUtils::ParsedName src,
diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed.cc b/tensorflow/core/distributed_runtime/collective_rma_distributed.cc
index d4c47cab49..b9a3502131 100644
--- a/tensorflow/core/distributed_runtime/collective_rma_distributed.cc
+++ b/tensorflow/core/distributed_runtime/collective_rma_distributed.cc
@@ -65,11 +65,13 @@ void CollectiveRemoteAccessDistributed::RecvFromPeer(
const string& peer_device, const string& peer_task, bool peer_is_local,
const string& key, Device* to_device, DeviceContext* to_device_ctx,
const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
- const DeviceLocality& client_locality, const StatusCallback& done) {
+ const DeviceLocality& client_locality, int dev_to_dev_stream_index,
+ const StatusCallback& done) {
if (peer_is_local) {
CollectiveRemoteAccessLocal::RecvFromPeer(
peer_device, peer_task, peer_is_local, key, to_device, to_device_ctx,
- to_alloc_attr, to_tensor, client_locality, done);
+ to_alloc_attr, to_tensor, client_locality, dev_to_dev_stream_index,
+ done);
return;
}
@@ -83,7 +85,8 @@ void CollectiveRemoteAccessDistributed::RecvFromPeer(
// Logic to be executed on the RecvBufAsync callback.
auto recv_buf_callback = [this, state, peer_task, to_device, to_alloc_attr,
- to_device_ctx, to_tensor, done](const Status& s) {
+ to_device_ctx, to_tensor, dev_to_dev_stream_index,
+ done](const Status& s) {
if (s.ok()) {
// In this generic implementation the bytes come back in the
// RPC response protobuf rather than via RDMA so we need to copy
@@ -119,7 +122,7 @@ void CollectiveRemoteAccessDistributed::RecvFromPeer(
CopyTensor::ViaDMA("", // edge name (non-existent)
nullptr /*send_dev_ctx*/, to_device_ctx, cpu_dev,
to_device, cpu_attr, to_alloc_attr, cpu_tensor,
- to_tensor,
+ to_tensor, dev_to_dev_stream_index,
[this, cpu_tensor, done](const Status& s) {
delete cpu_tensor;
// This callback must not block, so execute
diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed.h b/tensorflow/core/distributed_runtime/collective_rma_distributed.h
index cfa9110f47..9434cacbca 100644
--- a/tensorflow/core/distributed_runtime/collective_rma_distributed.h
+++ b/tensorflow/core/distributed_runtime/collective_rma_distributed.h
@@ -37,6 +37,7 @@ class CollectiveRemoteAccessDistributed : public CollectiveRemoteAccessLocal {
DeviceContext* to_device_ctx,
const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
const DeviceLocality& client_locality,
+ int dev_to_dev_stream_index,
const StatusCallback& done) override;
void StartAbort(const Status& s) override;
diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc b/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc
index a552f81f58..bfd312410c 100644
--- a/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc
+++ b/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc
@@ -280,7 +280,7 @@ TEST_F(CollRMADistTest, ProdFirstOK) {
"/job:worker/replica:0/task:1", // peer_task
false, // peer_is_local
kBufKey, dst_device, to_device_ctx, alloc_attr_, &to_tensor_,
- device_locality_,
+ device_locality_, 0 /*dev_to_dev_stream_index*/,
[this, &consumer_status, &consumer_note](const Status& s) {
consumer_status = s;
consumer_note.Notify();
@@ -309,7 +309,7 @@ TEST_F(CollRMADistTest, ConsFirstOK) {
"/job:worker/replica:0/task:1", // peer_task
false, // peer_is_local
kBufKey, dst_device, to_device_ctx, alloc_attr_, &to_tensor_,
- device_locality_,
+ device_locality_, 0 /*dev_to_dev_stream_index*/,
[this, &consumer_status, &consumer_note](const Status& s) {
consumer_status = s;
consumer_note.Notify();
@@ -342,7 +342,7 @@ TEST_F(CollRMADistTest, ConsFirstAbort) {
"/job:worker/replica:0/task:1", // peer_task
false, // peer_is_local
kBufKey, dst_device, to_device_ctx, alloc_attr_, &to_tensor_,
- device_locality_,
+ device_locality_, 0 /*dev_to_dev_stream_index*/,
[this, &consumer_status, &consumer_note](const Status& s) {
consumer_status = s;
consumer_note.Notify();