author     A. Unique TensorFlower <gardener@tensorflow.org>   2018-01-30 10:43:03 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>    2018-01-30 12:33:54 -0800
commit     4463d105a8a4a83642b9709ba79310e8f4ddf577 (patch)
tree       240e9a0a9a6b9ad956c704776a33126ba00cbfe8 /tensorflow/contrib/mpi
parent     8f0e7207774279f4fe50f4d6c4fbd576e2941463 (diff)
Cleanup: Ran clang-format on all files in tensorflow/contrib/.../*.{cc,h}.
PiperOrigin-RevId: 183855242
Diffstat (limited to 'tensorflow/contrib/mpi')
-rw-r--r--  tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc  218
-rw-r--r--  tensorflow/contrib/mpi/mpi_rendezvous_mgr.h    11
-rw-r--r--  tensorflow/contrib/mpi/mpi_server_lib.cc        2
-rw-r--r--  tensorflow/contrib/mpi/mpi_utils.h              2
4 files changed, 119 insertions, 114 deletions
diff --git a/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc b/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc
index 8d14a3ef04..c2c42b8ed7 100644
--- a/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc
+++ b/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc
@@ -24,11 +24,11 @@ limitations under the License.
#include <utility>
#include <vector>
-#include "tensorflow/core/distributed_runtime/tensor_coding.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
#include "tensorflow/core/distributed_runtime/session_mgr.h"
+#include "tensorflow/core/distributed_runtime/tensor_coding.h"
namespace tensorflow {
@@ -62,7 +62,6 @@ BaseRemoteRendezvous* MPIRendezvousMgr::Create(int64 step_id,
void MPIRemoteRendezvous::RecvFromRemoteAsync(
const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& recv_args,
DoneCallback done) {
-
Status s = Status::OK();
MPIRequestTensorCall* rendezvous_call = new MPIRequestTensorCall();
@@ -103,37 +102,37 @@ void MPIRemoteRendezvous::RecvFromRemoteAsync(
// Create the function which is called when the Tensor is sent by the remote
const int64 temp1 = step_id_;
rendezvous_call->recv_call_ =
- [this, parsed, recv_args, done, dst, temp1, rendezvous_call](
- MPIRecvTensorResponse mpi_response) {
- Status s;
- Device* dst_device;
- if (s.ok()) {
- s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_device);
- CHECK(s.ok()) << "Device lookup failed";
- }
-
- VLOG(3) << "MPI Received tensor " << parsed.FullKey()
- << " @ step: " << temp1
- << " single-send: " << mpi_response.singlesend();
-
- Tensor val;
- if (mpi_response.singlesend()) {
- dst_device->MakeTensorFromProto(mpi_response.response().tensor(),
- recv_args.alloc_attrs, &val);
- } else {
- TensorResponse tr;
- tr.InitAlloc(dst_device, recv_args.alloc_attrs);
- tr.InitPartial(mpi_response.response());
- const size_t nBytes = tr.tensor().TotalBytes();
- void* data = const_cast<void*>(DMAHelper::base(&tr.tensor()));
- MPI_Status status;
- MPI_CHECK(MPI_Recv(data, static_cast<int>(nBytes), MPI_BYTE, dst,
- TAG_SENDTENSOR2, MPI_COMM_WORLD, &status));
- val = std::move(tr.tensor());
- }
-
- done(s, Args(), recv_args, val, mpi_response.response().is_dead());
- };
+ [this, parsed, recv_args, done, dst, temp1,
+ rendezvous_call](MPIRecvTensorResponse mpi_response) {
+ Status s;
+ Device* dst_device;
+ if (s.ok()) {
+ s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_device);
+ CHECK(s.ok()) << "Device lookup failed";
+ }
+
+ VLOG(3) << "MPI Received tensor " << parsed.FullKey()
+ << " @ step: " << temp1
+ << " single-send: " << mpi_response.singlesend();
+
+ Tensor val;
+ if (mpi_response.singlesend()) {
+ dst_device->MakeTensorFromProto(mpi_response.response().tensor(),
+ recv_args.alloc_attrs, &val);
+ } else {
+ TensorResponse tr;
+ tr.InitAlloc(dst_device, recv_args.alloc_attrs);
+ tr.InitPartial(mpi_response.response());
+ const size_t nBytes = tr.tensor().TotalBytes();
+ void* data = const_cast<void*>(DMAHelper::base(&tr.tensor()));
+ MPI_Status status;
+ MPI_CHECK(MPI_Recv(data, static_cast<int>(nBytes), MPI_BYTE, dst,
+ TAG_SENDTENSOR2, MPI_COMM_WORLD, &status));
+ val = std::move(tr.tensor());
+ }
+
+ done(s, Args(), recv_args, val, mpi_response.response().is_dead());
+ };
MPIRendezvousMgr* mgr =
reinterpret_cast<MPIRendezvousMgr*>(this->rendezvous_mgr_);
@@ -159,9 +158,11 @@ void MPIRendezvousMgr::AddRequest(RecvTensorRequest request,
TF_CHECK_OK(Rendezvous::ParseKey(key, &parsed));
MPIRecvTensorCallBack send_cb = [this, mpi_dst, parsed](
- const Status& status, const Rendezvous::Args& send_args,
- const Rendezvous::Args& recv_args, const Tensor& val, bool is_dead,
- MPISendTensorCall* mpi_send_call) {
+ const Status& status,
+ const Rendezvous::Args& send_args,
+ const Rendezvous::Args& recv_args,
+ const Tensor& val, bool is_dead,
+ MPISendTensorCall* mpi_send_call) {
// TODO(jbedorf) this should be a loop over max size
CHECK(mpi_send_call->mRes_.ByteSize() < INT_MAX)
<< "Buffer too large for single transfer";
@@ -194,74 +195,78 @@ void MPIRendezvousMgr::AddRequest(RecvTensorRequest request,
};
// Wrapper around the read callback to place the callback on our queue
- Rendezvous::DoneCallback done_cb = [this, parsed, step_id, send_cb](
- const Status& status, const Rendezvous::Args& send_args,
- const Rendezvous::Args& recv_args, const Tensor& val, bool is_dead) {
- if (!status.ok()) {
- CHECK(status.ok()) << "RecvLocalAsync was not ok, key: "
- << parsed.FullKey() << " step: " << step_id
- << " error message: " << status.error_message();
- return;
- }
-
- VLOG(3) << "MPI Sending tensor " << parsed.FullKey()
- << " @ step: " << step_id << std::endl;
-
- auto mpi_send_call = new MPISendTensorCall();
- mpi_send_call->Init(parsed, step_id, is_dead);
-
- Device* src_dev = nullptr;
- Status s = this->worker_env_2->device_mgr->LookupDevice(parsed.src_device,
- &src_dev);
- CHECK(s.ok()) << "src device not found";
-
- // Control if shape and data should be sent together or if we can optimize
- // it in two different transfers, thereby reducing memory copies
- bool doOptimalTransfer = true;
- if (!DataTypeCanUseMemcpy(val.dtype())) doOptimalTransfer = false;
- if (val.TotalBytes() < 1024) doOptimalTransfer = false;
-
- doOptimalTransfer = doOptimalTransfer && use_optimal_transfer_;
-
- if (doOptimalTransfer) {
- // First send the Tensor description, then the data in a follow-up transfer
- mpi_send_call->mRes_.mutable_response()->mutable_tensor()->set_dtype(
- val.dtype());
- val.shape().AsProto(mpi_send_call->mRes_.mutable_response()
- ->mutable_tensor()
- ->mutable_tensor_shape());
- mpi_send_call->mRes_.set_singlesend(false);
- } else {
- // Send the Tensor description and data in a single transfer
- if (src_dev->tensorflow_gpu_device_info() &&
- (!send_args.alloc_attrs.on_host())) {
- Notification n;
- GPUUtil::SetProtoFromGPU(
- val, src_dev, send_args.device_context,
- mpi_send_call->mRes_.mutable_response()->mutable_tensor(), is_dead,
- [&n, &s](const Status& s_) {
- s = s_;
- n.Notify();
- });
- n.WaitForNotification();
- } else {
- val.AsProtoTensorContent(
- mpi_send_call->mRes_.mutable_response()->mutable_tensor());
- }
- }
-
- std::function<MPISendTensorCall*()> res = std::bind(
- send_cb, status, send_args, recv_args, val, is_dead, mpi_send_call);
-
- SendQueueEntry req(parsed.FullKey().ToString().c_str(), std::move(res));
-
- this->QueueSendRequest(req);
-
- // Wait for the notification that indicates the tensor has been
- // successfully transmitted to the remote process. Only needed if we
- // have not serialized the tensor into the proto
- if (doOptimalTransfer) mpi_send_call->n_.WaitForNotification();
- }; // done_cb
+ Rendezvous::DoneCallback done_cb =
+ [this, parsed, step_id, send_cb](
+ const Status& status, const Rendezvous::Args& send_args,
+ const Rendezvous::Args& recv_args, const Tensor& val, bool is_dead) {
+ if (!status.ok()) {
+ CHECK(status.ok())
+ << "RecvLocalAsync was not ok, key: " << parsed.FullKey()
+ << " step: " << step_id
+ << " error message: " << status.error_message();
+ return;
+ }
+
+ VLOG(3) << "MPI Sending tensor " << parsed.FullKey()
+ << " @ step: " << step_id << std::endl;
+
+ auto mpi_send_call = new MPISendTensorCall();
+ mpi_send_call->Init(parsed, step_id, is_dead);
+
+ Device* src_dev = nullptr;
+ Status s = this->worker_env_2->device_mgr->LookupDevice(
+ parsed.src_device, &src_dev);
+ CHECK(s.ok()) << "src device not found";
+
+ // Control if shape and data should be sent together or if we can
+ // optimize it in two different transfers, thereby reducing memory
+ // copies
+ bool doOptimalTransfer = true;
+ if (!DataTypeCanUseMemcpy(val.dtype())) doOptimalTransfer = false;
+ if (val.TotalBytes() < 1024) doOptimalTransfer = false;
+
+ doOptimalTransfer = doOptimalTransfer && use_optimal_transfer_;
+
+ if (doOptimalTransfer) {
+ // First send the Tensor description, then the data in a follow-up
+ // transfer
+ mpi_send_call->mRes_.mutable_response()->mutable_tensor()->set_dtype(
+ val.dtype());
+ val.shape().AsProto(mpi_send_call->mRes_.mutable_response()
+ ->mutable_tensor()
+ ->mutable_tensor_shape());
+ mpi_send_call->mRes_.set_singlesend(false);
+ } else {
+ // Send the Tensor description and data in a single transfer
+ if (src_dev->tensorflow_gpu_device_info() &&
+ (!send_args.alloc_attrs.on_host())) {
+ Notification n;
+ GPUUtil::SetProtoFromGPU(
+ val, src_dev, send_args.device_context,
+ mpi_send_call->mRes_.mutable_response()->mutable_tensor(),
+ is_dead, [&n, &s](const Status& s_) {
+ s = s_;
+ n.Notify();
+ });
+ n.WaitForNotification();
+ } else {
+ val.AsProtoTensorContent(
+ mpi_send_call->mRes_.mutable_response()->mutable_tensor());
+ }
+ }
+
+ std::function<MPISendTensorCall*()> res = std::bind(
+ send_cb, status, send_args, recv_args, val, is_dead, mpi_send_call);
+
+ SendQueueEntry req(parsed.FullKey().ToString().c_str(), std::move(res));
+
+ this->QueueSendRequest(req);
+
+ // Wait for the notification that indicates the tensor has been
+ // successfully transmitted to the remote process. Only needed if we
+ // have not serialized the tensor into the proto
+ if (doOptimalTransfer) mpi_send_call->n_.WaitForNotification();
+ }; // done_cb
worker_env_2->compute_pool->Schedule([this, step_id, parsed, done_cb]() {
this->RecvLocalAsync(step_id, parsed, done_cb);
@@ -293,9 +298,8 @@ void MPIRendezvousMgr::MPIBackgroundThread() {
}
// Remove sends that have been completed
- active_sends.remove_if([](std::unique_ptr<MPISendTensorCall>& i) {
- return i->IsFinished();
- });
+ active_sends.remove_if(
+ [](std::unique_ptr<MPISendTensorCall>& i) { return i->IsFinished(); });
// send a Tensor request
RequestQueueEntry req;
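Besides reflowing the lambdas, the hunks above spell out the transfer protocol this file implements: when doOptimalTransfer holds, the sender ships a TensorProto carrying only dtype and shape (singlesend = false) and pushes the raw tensor bytes in a second message, which the recv_call_ lambda receives straight into the destination allocation via MPI_Recv on TAG_SENDTENSOR2. Below is a minimal sketch of that second, payload-only leg; the helper name and parameters are illustrative, not the file's actual API.

#include <mpi.h>

#include <cstddef>

// Sketch of the payload leg of the two-transfer path: the first message
// already carried the dtype/shape proto, so this one moves only the raw
// tensor bytes and no proto copy of the data is ever made.
void SendTensorPayload(const void* data, std::size_t nbytes, int dst,
                       int tag /* e.g. TAG_SENDTENSOR2 */) {
  // MPI counts are ints, which is why the real code CHECKs ByteSize()
  // against INT_MAX before attempting a single transfer.
  MPI_Send(data, static_cast<int>(nbytes), MPI_BYTE, dst, tag,
           MPI_COMM_WORLD);
}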
diff --git a/tensorflow/contrib/mpi/mpi_rendezvous_mgr.h b/tensorflow/contrib/mpi/mpi_rendezvous_mgr.h
index ca42ee2f6d..e665922135 100644
--- a/tensorflow/contrib/mpi/mpi_rendezvous_mgr.h
+++ b/tensorflow/contrib/mpi/mpi_rendezvous_mgr.h
@@ -18,12 +18,12 @@ limitations under the License.
#ifdef TENSORFLOW_USE_MPI
-#include <queue>
-#include <thread>
#include <list>
-#include <string>
-#include <memory>
#include <map>
+#include <memory>
+#include <queue>
+#include <string>
+#include <thread>
#include <unordered_map>
#include <utility>
#include <vector>
@@ -160,7 +160,8 @@ class MPIRendezvousMgr : public BaseRendezvousMgr {
private:
typedef std::function<MPISendTensorCall*(
const Status&, const Rendezvous::Args&, const Rendezvous::Args&,
- const Tensor&, const bool, MPISendTensorCall*)> MPIRecvTensorCallBack;
+ const Tensor&, const bool, MPISendTensorCall*)>
+ MPIRecvTensorCallBack;
typedef std::pair<std::string, std::function<void()>> RequestQueueEntry;
typedef std::pair<std::string, std::function<MPISendTensorCall*()>>
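The wrapped typedef in this hunk shows why clang-format pushes the alias name MPIRecvTensorCallBack onto its own line: in a typedef the name trails the std::function type and overruns the 80-column limit. A C++11 using-alias puts the name first and avoids the wrap; a sketch of the equivalent declaration follows (not what the header actually uses; Status, Rendezvous, Tensor, and MPISendTensorCall are the TensorFlow types from the header above).

#include <functional>

// Equivalent alias; compiles only in the context of the header's own
// includes, which declare the TensorFlow types named here.
using MPIRecvTensorCallBack = std::function<MPISendTensorCall*(
    const Status&, const Rendezvous::Args&, const Rendezvous::Args&,
    const Tensor&, bool, MPISendTensorCall*)>;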
diff --git a/tensorflow/contrib/mpi/mpi_server_lib.cc b/tensorflow/contrib/mpi/mpi_server_lib.cc
index d585c0565e..a31fa9ce0b 100644
--- a/tensorflow/contrib/mpi/mpi_server_lib.cc
+++ b/tensorflow/contrib/mpi/mpi_server_lib.cc
@@ -22,8 +22,8 @@ limitations under the License.
#include "grpc/support/alloc.h"
-#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
+#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/env.h"
diff --git a/tensorflow/contrib/mpi/mpi_utils.h b/tensorflow/contrib/mpi/mpi_utils.h
index 45e21f2b25..fa297c28cb 100644
--- a/tensorflow/contrib/mpi/mpi_utils.h
+++ b/tensorflow/contrib/mpi/mpi_utils.h
@@ -18,8 +18,8 @@ limitations under the License.
#ifdef TENSORFLOW_USE_MPI
-#include <string>
#include <map>
+#include <string>
#include <vector>
#include "tensorflow/core/lib/strings/str_util.h"
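The header hunks in all three files are the same mechanical change: with clang-format's SortIncludes option enabled, each block of #include directives is alphabetized. A sketch of the order it settles on, mirroring the mpi_utils.h hunk above:

// Standard-library block, alphabetized by clang-format's SortIncludes.
#include <map>
#include <string>
#include <vector>

// Project headers form a separate block and are sorted independently.
#include "tensorflow/core/lib/strings/str_util.h"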