/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/collective_rma_distributed.h"

#include "google/protobuf/any.pb.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
#include "tensorflow/core/distributed_runtime/test_utils.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/transport_options.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h"
#include "tensorflow/core/util/device_name_utils.h"

// The only interesting method on CollectiveRemoteAccessDistributed
// that's not on CollectiveRemoteAccessLocal is RecvFromPeer, which
// issues a RecvBufAsync call against a WorkerInterface.  That's all
// that's tested here.  Note that RecvFromPeer can do a
// DeviceResolverInterface::GetDeviceLocalityAsync call in preparation
// for the RecvBufAsync.

namespace tensorflow {
namespace {

static Device* NewDevice(const string& type, const string& name) {
  class FakeDevice : public Device {
   public:
    explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {}
    Status Sync() override { return Status::OK(); }
    Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; }
  };
  DeviceAttributes attr;
  attr.set_name(name);
  attr.set_device_type(type);
  attr.mutable_locality()->set_numa_node(3);  // a non-default value
  return new FakeDevice(attr);
}

static int64 kStepId = 123;

class FakeWorker : public TestWorkerInterface {
 public:
  FakeWorker(const string& name, DeviceMgr* dev_mgr,
             DeviceResolverDistributed* dres)
      : name_(name),
        device_mgr_(dev_mgr),
        device_resolver_(dres),
        buf_rendezvous_(kStepId) {}

  // Direct access to a BufRendezvous that holds whatever the remote
  // worker is supposed to have.
  BufRendezvous* buf_rendezvous() { return &buf_rendezvous_; }

  void GetStatusAsync(const GetStatusRequest* request,
                      GetStatusResponse* response,
                      StatusCallback done) override {
    std::vector<DeviceAttributes> dev_attr;
    device_mgr_->ListDeviceAttributes(&dev_attr);
    for (const auto& da : dev_attr) {
      *response->add_device_attributes() = da;
    }
    done(Status::OK());
  }
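  // RecvBufAsync is satisfied from the process-local buf_rendezvous_ table
  // rather than a real transport; the producer's bytes are copied into the
  // response's transport_options.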
  void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
                    RecvBufResponse* response, StatusCallback done) override {
    opts->SetCancelCallback([this]() {
      // Within this test the call is satisfied by a process-local
      // BufRendezvous table.  In a real application the BufRendezvous
      // would be on the other side of a network hop, so call
      // BufRendezvous::StartAbort() from a separate thread to be
      // more consistent with that situation and avoid mutex deadlock.
      SchedClosure([this]() {
        Env::Default()->SleepForMicroseconds(100);
        buf_rendezvous_.StartAbort(errors::Internal("Cancelled"));
      });
    });
    buf_rendezvous_.ConsumeBuf(
        request->buf_rendezvous_key(),
        [this, opts, request, response, done](const Status& s,
                                              BufRendezvous::Hook* h) {
          if (s.ok()) {
            opts->ClearCancelCallback();
            // Since this is not really RDMA into pre-allocated memory, send
            // the bytes in the response.
            RecvBufRespExtra extra;
            int64 num_bytes = h->prod_value->TotalBytes();
            extra.set_tensor_content(string(
                reinterpret_cast<const char*>(DMAHelper::base(h->prod_value)),
                num_bytes));
            response->mutable_transport_options()->PackFrom(extra);
          }
          done(s);
          if (h) BufRendezvous::DoneWithHook(h);
        });
  }

 private:
  string name_;
  DeviceMgr* device_mgr_;
  DeviceResolverDistributed* device_resolver_;
  BufRendezvous buf_rendezvous_;
};

class FakeCache : public TestWorkerCache {
 public:
  // Override the Locality methods to actually pass through to the
  // worker.
  bool GetDeviceLocalityNonBlocking(const string& device,
                                    DeviceLocality* locality) override {
    return false;
  }

  void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality,
                              StatusCallback done) override {
    string task_name;
    string dev_part;
    if (!DeviceNameUtils::SplitDeviceName(device, &task_name, &dev_part)) {
      done(errors::Internal("failed to parse device name"));
      return;
    }
    auto it = workers_.find(task_name);
    if (it == workers_.end()) {
      done(errors::Internal("failed to find worker ", task_name));
      return;
    }
    WorkerInterface* wi = it->second;
    GetStatusRequest req;
    GetStatusResponse resp;
    Notification note;
    Status status = wi->GetStatus(&req, &resp);
    if (!status.ok()) {
      done(status);
      return;
    }
    for (const auto& it : resp.device_attributes()) {
      if (it.name() == device) {
        *locality = it.locality();
        done(Status::OK());
        return;
      }
    }
    done(errors::Internal("device not found: ", device));
  }
};
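// Fixture that builds two fake workers, each with a single CPU device, and a
// CollectiveRemoteAccessDistributed instance acting on behalf of worker 0.
// All tests issue RecvFromPeer requests from worker 0 to worker 1.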
class CollRMADistTest : public ::testing::Test {
 protected:
  CollRMADistTest() {}

  ~CollRMADistTest() override {
    for (DeviceMgr* dm : device_mgrs_) {
      delete dm;
    }
    for (auto it : dev_resolvers_) {
      delete it.second;
    }
    for (FakeWorker* w : workers_) {
      delete w;
    }
  }

  void SetUp() override {
    const int num_workers = 2;
    const int num_devices = 1;
    string device_type = "CPU";
    ConfigProto config;
    string dev0_worker_name;
    for (int w = 0; w < num_workers; ++w) {
      string name = strings::StrCat("/job:worker/replica:0/task:", w);
      if (w == 0) {
        dev0_worker_name = name;
        // TODO(tucker): Change to use config when available.
        // config.set_collective_group_leader(name);
      }
      DefineWorker(config, name, device_type, num_devices);
    }
    // All tests simulate requests from worker 0 to worker 1.
    rma_.reset(new CollectiveRemoteAccessDistributed(
        device_mgrs_[0], dev_resolvers_[dev0_worker_name], &wc_, kStepId));

    const int kNumElts = 8;
    expected_value_ = Tensor(DT_FLOAT, {kNumElts});
    to_tensor_ = Tensor(DT_FLOAT, {kNumElts});
    auto exp_alias = expected_value_.flat<float>();
    auto to_alias = to_tensor_.flat<float>();
    for (int i = 0; i < kNumElts; ++i) {
      exp_alias(i) = i;
      to_alias(i) = -1;
    }
  }

  void DefineWorker(const ConfigProto& config, const string& worker_name,
                    const string& device_type, int num_devices) {
    std::vector<Device*> devices;
    for (int i = 0; i < num_devices; ++i) {
      devices.push_back(NewDevice(
          device_type,
          strings::StrCat(worker_name, "/device:", device_type, ":", i)));
    }
    DeviceMgr* dev_mgr = new DeviceMgr(devices);
    device_mgrs_.push_back(dev_mgr);
    std::vector<string>* dv = &dev_by_task_[worker_name];
    for (auto d : devices) {
      dv->push_back(d->name());
    }
    DeviceResolverDistributed* dev_res =
        new DeviceResolverDistributed(dev_mgr, &wc_, worker_name);
    dev_resolvers_[worker_name] = dev_res;
    FakeWorker* fw = new FakeWorker(worker_name, dev_mgr, dev_res);
    workers_.push_back(fw);
    wc_.AddWorker(worker_name, fw);
  }

  void ValidateResultTensor() {
    ASSERT_EQ(expected_value_.NumElements(), to_tensor_.NumElements());
    for (int i = 0; i < to_tensor_.NumElements(); ++i) {
      EXPECT_FLOAT_EQ(expected_value_.flat<float>()(i),
                      to_tensor_.flat<float>()(i));
    }
  }

  FakeCache wc_;
  CancellationManager cm_;
  std::vector<DeviceMgr*> device_mgrs_;
  std::unordered_map<string, DeviceResolverDistributed*> dev_resolvers_;
  std::unordered_map<string, std::vector<string>> dev_by_task_;
  std::vector<FakeWorker*> workers_;
  std::unique_ptr<CollectiveRemoteAccessDistributed> rma_;
  mutex mu_;
  int num_done_ GUARDED_BY(mu_);
  condition_variable done_;
  Tensor expected_value_;
  Tensor to_tensor_;
  CallOptions opts_;
  DeviceLocality device_locality_;
  AllocatorAttributes alloc_attr_;
};

TEST_F(CollRMADistTest, ProdFirstOK) {
  Notification consumer_note;
  Notification producer_note;
  Status consumer_status;
  Status producer_status;
  FakeWorker* wi = workers_[1];
  const string kBufKey = "fake_buf_key";
  wi->buf_rendezvous()->ProvideBuf(
      kBufKey, nullptr /*device*/, nullptr /*dev_ctx*/, &expected_value_,
      AllocatorAttributes(),
      [this, &producer_note, &producer_status](const Status& s) {
        producer_status.Update(s);
        producer_note.Notify();
      });
  Status status;
  Device* dst_device = nullptr;
  string dev_name = "CPU:0";
  TF_EXPECT_OK(device_mgrs_[0]->LookupDevice(dev_name, &dst_device));
  DeviceContext* to_device_ctx = nullptr;
  rma_->RecvFromPeer(
      "/job:worker/replica:0/task:1/device:" + dev_name,  // peer_dev
      "/job:worker/replica:0/task:1",                     // peer_task
      false,                                              // peer_is_local
      kBufKey, dst_device, to_device_ctx, alloc_attr_, &to_tensor_,
      device_locality_, 0 /*dev_to_dev_stream_index*/,
      [this, &consumer_status, &consumer_note](const Status& s) {
        consumer_status = s;
        consumer_note.Notify();
      });
  consumer_note.WaitForNotification();
  TF_EXPECT_OK(consumer_status);
  producer_note.WaitForNotification();
  TF_EXPECT_OK(producer_status);
  ValidateResultTensor();
}
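// Same exchange as ProdFirstOK, except the consumer's RecvFromPeer is issued
// before the producer provides the buffer.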
TEST_F(CollRMADistTest, ConsFirstOK) {
  Notification consumer_note;
  Notification producer_note;
  Status consumer_status;
  Status producer_status;
  FakeWorker* wi = workers_[1];
  const string kBufKey = "fake_buf_key";
  Status status;
  Device* dst_device = nullptr;
  string dev_name = "CPU:0";
  TF_EXPECT_OK(device_mgrs_[0]->LookupDevice(dev_name, &dst_device));
  DeviceContext* to_device_ctx = nullptr;
  rma_->RecvFromPeer(
      "/job:worker/replica:0/task:1/device:" + dev_name,  // peer_dev
      "/job:worker/replica:0/task:1",                     // peer_task
      false,                                              // peer_is_local
      kBufKey, dst_device, to_device_ctx, alloc_attr_, &to_tensor_,
      device_locality_, 0 /*dev_to_dev_stream_index*/,
      [this, &consumer_status, &consumer_note](const Status& s) {
        consumer_status = s;
        consumer_note.Notify();
      });
  wi->buf_rendezvous()->ProvideBuf(
      kBufKey, nullptr /*device*/, nullptr /*dev_ctx*/, &expected_value_,
      AllocatorAttributes(),
      [this, &producer_note, &producer_status](const Status& s) {
        producer_status.Update(s);
        producer_note.Notify();
      });
  consumer_note.WaitForNotification();
  TF_EXPECT_OK(consumer_status);
  producer_note.WaitForNotification();
  TF_EXPECT_OK(producer_status);
  ValidateResultTensor();
}

TEST_F(CollRMADistTest, ConsFirstAbort) {
  Notification consumer_note;
  Status consumer_status;
  const string kBufKey = "fake_buf_key";
  Status status;
  Device* dst_device = nullptr;
  string dev_name = "CPU:0";
  TF_EXPECT_OK(device_mgrs_[0]->LookupDevice(dev_name, &dst_device));
  DeviceContext* to_device_ctx = nullptr;
  rma_->RecvFromPeer(
      "/job:worker/replica:0/task:1/device:" + dev_name,  // peer_dev
      "/job:worker/replica:0/task:1",                     // peer_task
      false,                                              // peer_is_local
      kBufKey, dst_device, to_device_ctx, alloc_attr_, &to_tensor_,
      device_locality_, 0 /*dev_to_dev_stream_index*/,
      [this, &consumer_status, &consumer_note](const Status& s) {
        consumer_status = s;
        consumer_note.Notify();
      });
  rma_->StartAbort(errors::Internal("Deliberate Failure"));
  consumer_note.WaitForNotification();
  EXPECT_EQ(consumer_status.error_message(), "Cancelled");
}

}  // namespace
}  // namespace tensorflow