/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/collective_rma_distributed.h"

#include "tensorflow/core/common_runtime/base_collective_executor.h"
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/distributed_runtime/cancellable_call.h"
#include "tensorflow/core/distributed_runtime/request_id.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/platform/protobuf_internal.h"
#include "tensorflow/core/protobuf/transport_options.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h"

namespace tensorflow {

namespace {

// Wraps a RecvBuf RPC to a remote worker as a CancellableCall so that it can
// be aborted through the shared CancellationManager.
class RecvBufCall : public CancellableCall {
 public:
  RecvBufCall(int64 step_id, const string& peer_device,
              const string& peer_task, const string& key, Device* to_device,
              DeviceContext* to_device_ctx,
              const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
              const DeviceLocality& client_locality,
              const DeviceLocality& server_locality,
              CancellationManager* cancel_mgr, WorkerCacheInterface* wc)
      : CancellableCall(cancel_mgr, peer_task, wc) {
    req_.set_step_id(step_id);
    req_.set_buf_rendezvous_key(key);
    *req_.mutable_client_locality() = client_locality;
    *req_.mutable_server_locality() = server_locality;
    req_.set_num_bytes(to_tensor->TotalBytes());
    req_.set_buf_ptr(reinterpret_cast<int64>(DMAHelper::base(to_tensor)));
    req_.set_src_device(peer_device);
    req_.set_dst_device(to_device->name());
    req_.set_request_id(GetUniqueRequestId());
  }

  ~RecvBufCall() override {}

  void IssueCall(const StatusCallback& done) override {
    wi_->RecvBufAsync(&opts_, &req_, &resp_, done);
  }

  RecvBufRequest req_;
  RecvBufResponse resp_;
};

}  // namespace

void CollectiveRemoteAccessDistributed::RecvFromPeer(
    const string& peer_device, const string& peer_task, bool peer_is_local,
    const string& key, Device* to_device, DeviceContext* to_device_ctx,
    const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
    const DeviceLocality& client_locality, int dev_to_dev_stream_index,
    const StatusCallback& done) {
  if (peer_is_local) {
    CollectiveRemoteAccessLocal::RecvFromPeer(
        peer_device, peer_task, peer_is_local, key, to_device, to_device_ctx,
        to_alloc_attr, to_tensor, client_locality, dev_to_dev_stream_index,
        done);
    return;
  }

  // State that needs to be threaded through a couple of async calls
  // in order to make this function completely non-blocking.
  struct State {
    DeviceLocality server_locality;
    std::unique_ptr<RecvBufCall> call;
  };
  State* state = new State;

  // Logic to be executed on the RecvBufAsync callback.
  auto recv_buf_callback = [this, state, peer_task, to_device, to_alloc_attr,
                            to_device_ctx, to_tensor, dev_to_dev_stream_index,
                            done](const Status& s) {
    if (s.ok()) {
      // In this generic implementation the bytes come back in the
      // RPC response protobuf rather than via RDMA so we need to copy
      // them into the destination tensor here.
      RecvBufRespExtra extra;
      state->call->resp_.transport_options().UnpackTo(&extra);
      int64 num_bytes = extra.tensor_content().size();
      if (num_bytes != to_tensor->TotalBytes()) {
        done(errors::Internal("RecvBufResponse returned ", num_bytes,
                              " bytes where to_tensor expected ",
                              to_tensor->TotalBytes()));
        delete state;
        return;
      }
      if (to_device->tensorflow_gpu_device_info()) {
        // Move the bytes into a CPU tensor then use tensor-to-tensor copy.
        // Use GPU-registered memory for the CPU tensor so the transfer
        // goes faster.
        Device* cpu_dev = nullptr;
        Status status = dev_mgr_->LookupDevice("CPU:0", &cpu_dev);
        if (!status.ok()) {
          done(status);
          delete state;
          return;
        }
        AllocatorAttributes cpu_attr;
        cpu_attr.set_gpu_compatible(true);
        Tensor* cpu_tensor = new Tensor(cpu_dev->GetAllocator(cpu_attr),
                                        to_tensor->dtype(), to_tensor->shape());
        memcpy(DMAHelper::base(cpu_tensor), extra.tensor_content().data(),
               num_bytes);
        // Then copy it to the GPU.
        CopyTensor::ViaDMA("",  // edge name (non-existent)
                           nullptr /*send_dev_ctx*/, to_device_ctx, cpu_dev,
                           to_device, cpu_attr, to_alloc_attr, cpu_tensor,
                           to_tensor, dev_to_dev_stream_index,
                           [this, cpu_tensor, done](const Status& s) {
                             delete cpu_tensor;
                             // This callback must not block, so execute
                             // done in another thread.
                             SchedClosure([s, done] { done(s); });
                           });
        delete state;
        return;
      } else {
        // CPU device
        memcpy(DMAHelper::base(to_tensor), extra.tensor_content().data(),
               num_bytes);
      }
    }
    if (!s.ok() && errors::IsFailedPrecondition(s)) {
      dev_resolver_->ClearTask(peer_task);
    }

    delete state;
    done(s);
  };

  // Logic to execute once we have the device locality for the server-side
  // device.
  auto dev_locality_callback = [this, state, peer_device, peer_task, key,
                                to_device, to_device_ctx, to_alloc_attr,
                                to_tensor, client_locality,
                                recv_buf_callback](const Status& s) {
    if (!s.ok()) {
      recv_buf_callback(s);
    } else {
      state->call.reset(new RecvBufCall(
          step_id_, peer_device, peer_task, key, to_device, to_device_ctx,
          to_alloc_attr, to_tensor, client_locality, state->server_locality,
          &cancel_mgr_, worker_cache_));
      state->call->Start(recv_buf_callback);
    }
  };

  dev_resolver_->GetLocalityAsync(peer_device, peer_task,
                                  &state->server_locality,
                                  dev_locality_callback);
}

void CollectiveRemoteAccessDistributed::StartAbort(const Status& s) {
  CollectiveRemoteAccessLocal::StartAbort(s);
  cancel_mgr_.StartCancel();
}

}  // namespace tensorflow