aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc')
-rw-r--r--tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc356
1 files changed, 356 insertions, 0 deletions
diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc b/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc
new file mode 100644
index 0000000000..a552f81f58
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc
@@ -0,0 +1,356 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/distributed_runtime/collective_rma_distributed.h"
+
+#include "google/protobuf/any.pb.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/common_runtime/process_util.h"
+#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
+#include "tensorflow/core/distributed_runtime/test_utils.h"
+#include "tensorflow/core/framework/cancellation.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/protobuf/transport_options.pb.h"
+#include "tensorflow/core/protobuf/worker.pb.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+// The only interesting method on CollectiveRemoteAccessDistributed
+// that's not on CollectiveRemoteAccessLocal is RecvFromPeer which
+// issues a RecvBufAsync call against a WorkerInterface. That's all
+// that's tested here. Note that RecvFromPeer can do a
+// DeviceResolverInterface::GetDeviceLocalityAsync call in preparation
+// for the RecvBufAsync.
+
+namespace tensorflow {
+namespace {
+
+static Device* NewDevice(const string& type, const string& name) {
+ class FakeDevice : public Device {
+ public:
+ explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {}
+ Status Sync() override { return Status::OK(); }
+ Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; }
+ };
+ DeviceAttributes attr;
+ attr.set_name(name);
+ attr.set_device_type(type);
+ attr.mutable_locality()->set_numa_node(3); // a non-default value
+ return new FakeDevice(attr);
+}
+
+static int64 kStepId = 123;
+
+class FakeWorker : public TestWorkerInterface {
+ public:
+ FakeWorker(const string& name, DeviceMgr* dev_mgr,
+ DeviceResolverDistributed* dres)
+ : name_(name),
+ device_mgr_(dev_mgr),
+ device_resolver_(dres),
+ buf_rendezvous_(kStepId) {}
+
+ // Direct access to a BufRendezvous that holds whatever the remote
+ // worker is supposed to have.
+ BufRendezvous* buf_rendezvous() { return &buf_rendezvous_; }
+
+ void GetStatusAsync(const GetStatusRequest* request,
+ GetStatusResponse* response,
+ StatusCallback done) override {
+ std::vector<DeviceAttributes> dev_attr;
+ device_mgr_->ListDeviceAttributes(&dev_attr);
+ for (const auto& da : dev_attr) {
+ *response->add_device_attributes() = da;
+ }
+ done(Status::OK());
+ }
+
+ void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
+ RecvBufResponse* response, StatusCallback done) override {
+ opts->SetCancelCallback([this]() {
+ // Within this test the call is satisfied by a process-local
+ // BufRendezvous table. In real application the BufRendezvous
+ // would be on the other side of a network hop, so call
+ // BufRendezvous::StartAbort() from a separate thread to be
+ // more consistent with that situation and avoid mutex deadlock.
+ SchedClosure([this]() {
+ Env::Default()->SleepForMicroseconds(100);
+ buf_rendezvous_.StartAbort(errors::Internal("Cancelled"));
+ });
+ });
+ buf_rendezvous_.ConsumeBuf(
+ request->buf_rendezvous_key(),
+ [this, opts, request, response, done](const Status& s,
+ BufRendezvous::Hook* h) {
+ if (s.ok()) {
+ opts->ClearCancelCallback();
+ // Since this is not really RDMA into pre-allocated memory send the
+ // bytes in the response.
+ RecvBufRespExtra extra;
+ int64 num_bytes = h->prod_value->TotalBytes();
+ extra.set_tensor_content(string(
+ reinterpret_cast<const char*>(DMAHelper::base(h->prod_value)),
+ num_bytes));
+ response->mutable_transport_options()->PackFrom(extra);
+ }
+ done(s);
+ if (h) BufRendezvous::DoneWithHook(h);
+ });
+ }
+
+ private:
+ string name_;
+ DeviceMgr* device_mgr_;
+ DeviceResolverDistributed* device_resolver_;
+ BufRendezvous buf_rendezvous_;
+};
+
+class FakeCache : public TestWorkerCache {
+ public:
+ // Override the Locality methods to actually pass through to the
+ // worker.
+ bool GetDeviceLocalityNonBlocking(const string& device,
+ DeviceLocality* locality) override {
+ return false;
+ }
+
+ void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality,
+ StatusCallback done) override {
+ string task_name;
+ string dev_part;
+ if (!DeviceNameUtils::SplitDeviceName(device, &task_name, &dev_part)) {
+ done(errors::Internal("failed to parse device name"));
+ return;
+ }
+ auto it = workers_.find(task_name);
+ if (it == workers_.end()) {
+ done(errors::Internal("failed to find worker ", task_name));
+ return;
+ }
+ WorkerInterface* wi = it->second;
+ GetStatusRequest req;
+ GetStatusResponse resp;
+ Notification note;
+ Status status = wi->GetStatus(&req, &resp);
+ if (!status.ok()) {
+ done(status);
+ return;
+ }
+ for (const auto& it : resp.device_attributes()) {
+ if (it.name() == device) {
+ *locality = it.locality();
+ done(Status::OK());
+ return;
+ }
+ }
+ done(errors::Internal("device not found: ", device));
+ }
+};
+
+class CollRMADistTest : public ::testing::Test {
+ protected:
+ CollRMADistTest() {}
+
+ ~CollRMADistTest() override {
+ for (DeviceMgr* dm : device_mgrs_) {
+ delete dm;
+ }
+ for (auto it : dev_resolvers_) {
+ delete it.second;
+ }
+ for (FakeWorker* w : workers_) {
+ delete w;
+ }
+ }
+
+ void SetUp() override {
+ const int num_workers = 2;
+ const int num_devices = 1;
+ string device_type = "CPU";
+ ConfigProto config;
+ string dev0_worker_name;
+ for (int w = 0; w < num_workers; ++w) {
+ string name = strings::StrCat("/job:worker/replica:0/task:", w);
+ if (w == 0) {
+ dev0_worker_name = name;
+ // TODO(tucker): Change to use config when available.
+ // config.set_collective_group_leader(name);
+ }
+ DefineWorker(config, name, device_type, num_devices);
+ }
+ // All tests simulate requests from worker 0 to worker 1.
+ rma_.reset(new CollectiveRemoteAccessDistributed(
+ device_mgrs_[0], dev_resolvers_[dev0_worker_name], &wc_, kStepId));
+
+ const int kNumElts = 8;
+ expected_value_ = Tensor(DT_FLOAT, {kNumElts});
+ to_tensor_ = Tensor(DT_FLOAT, {kNumElts});
+ auto exp_alias = expected_value_.flat<float>();
+ auto to_alias = to_tensor_.flat<float>();
+ for (int i = 0; i < kNumElts; ++i) {
+ exp_alias(i) = i;
+ to_alias(i) = -1;
+ }
+ }
+
+ void DefineWorker(const ConfigProto& config, const string& worker_name,
+ const string& device_type, int num_devices) {
+ std::vector<Device*> devices;
+ for (int i = 0; i < num_devices; ++i) {
+ devices.push_back(NewDevice(
+ device_type,
+ strings::StrCat(worker_name, "/device:", device_type, ":", i)));
+ }
+ DeviceMgr* dev_mgr = new DeviceMgr(devices);
+ device_mgrs_.push_back(dev_mgr);
+ std::vector<string>* dv = &dev_by_task_[worker_name];
+ for (auto d : devices) {
+ dv->push_back(d->name());
+ }
+ DeviceResolverDistributed* dev_res =
+ new DeviceResolverDistributed(dev_mgr, &wc_, worker_name);
+ dev_resolvers_[worker_name] = dev_res;
+ FakeWorker* fw = new FakeWorker(worker_name, dev_mgr, dev_res);
+ workers_.push_back(fw);
+ wc_.AddWorker(worker_name, fw);
+ }
+
+ void ValidateResultTensor() {
+ ASSERT_EQ(expected_value_.NumElements(), to_tensor_.NumElements());
+ for (int i = 0; i < to_tensor_.NumElements(); ++i) {
+ EXPECT_FLOAT_EQ(expected_value_.flat<float>()(i),
+ to_tensor_.flat<float>()(i));
+ }
+ }
+
+ FakeCache wc_;
+ CancellationManager cm_;
+ std::vector<DeviceMgr*> device_mgrs_;
+ std::unordered_map<string, DeviceResolverDistributed*> dev_resolvers_;
+ std::unordered_map<string, std::vector<string>> dev_by_task_;
+ std::vector<FakeWorker*> workers_;
+ std::unique_ptr<CollectiveRemoteAccessDistributed> rma_;
+ mutex mu_;
+ int num_done_ GUARDED_BY(mu_);
+ condition_variable done_;
+ Tensor expected_value_;
+ Tensor to_tensor_;
+ CallOptions opts_;
+ DeviceLocality device_locality_;
+ AllocatorAttributes alloc_attr_;
+};
+
+TEST_F(CollRMADistTest, ProdFirstOK) {
+ Notification consumer_note;
+ Notification producer_note;
+ Status consumer_status;
+ Status producer_status;
+ FakeWorker* wi = workers_[1];
+ const string kBufKey = "fake_buf_key";
+ wi->buf_rendezvous()->ProvideBuf(
+ kBufKey, nullptr /*device*/, nullptr /*dev_ctx*/, &expected_value_,
+ AllocatorAttributes(),
+ [this, &producer_note, &producer_status](const Status& s) {
+ producer_status.Update(s);
+ producer_note.Notify();
+ });
+ Status status;
+ Device* dst_device = nullptr;
+ string dev_name = "CPU:0";
+ TF_EXPECT_OK(device_mgrs_[0]->LookupDevice(dev_name, &dst_device));
+ DeviceContext* to_device_ctx = nullptr;
+ rma_->RecvFromPeer(
+ "/job:worker/replica:0/task:1/device:" + dev_name, // peer_dev
+ "/job:worker/replica:0/task:1", // peer_task
+ false, // peer_is_local
+ kBufKey, dst_device, to_device_ctx, alloc_attr_, &to_tensor_,
+ device_locality_,
+ [this, &consumer_status, &consumer_note](const Status& s) {
+ consumer_status = s;
+ consumer_note.Notify();
+ });
+ consumer_note.WaitForNotification();
+ TF_EXPECT_OK(consumer_status);
+ producer_note.WaitForNotification();
+ TF_EXPECT_OK(producer_status);
+ ValidateResultTensor();
+}
+
+TEST_F(CollRMADistTest, ConsFirstOK) {
+ Notification consumer_note;
+ Notification producer_note;
+ Status consumer_status;
+ Status producer_status;
+ FakeWorker* wi = workers_[1];
+ const string kBufKey = "fake_buf_key";
+ Status status;
+ Device* dst_device = nullptr;
+ string dev_name = "CPU:0";
+ TF_EXPECT_OK(device_mgrs_[0]->LookupDevice(dev_name, &dst_device));
+ DeviceContext* to_device_ctx = nullptr;
+ rma_->RecvFromPeer(
+ "/job:worker/replica:0/task:1/device:" + dev_name, // peer_dev
+ "/job:worker/replica:0/task:1", // peer_task
+ false, // peer_is_local
+ kBufKey, dst_device, to_device_ctx, alloc_attr_, &to_tensor_,
+ device_locality_,
+ [this, &consumer_status, &consumer_note](const Status& s) {
+ consumer_status = s;
+ consumer_note.Notify();
+ });
+ wi->buf_rendezvous()->ProvideBuf(
+ kBufKey, nullptr /*device*/, nullptr /*dev_ctx*/, &expected_value_,
+ AllocatorAttributes(),
+ [this, &producer_note, &producer_status](const Status& s) {
+ producer_status.Update(s);
+ producer_note.Notify();
+ });
+ consumer_note.WaitForNotification();
+ TF_EXPECT_OK(consumer_status);
+ producer_note.WaitForNotification();
+ TF_EXPECT_OK(producer_status);
+ ValidateResultTensor();
+}
+
+TEST_F(CollRMADistTest, ConsFirstAbort) {
+ Notification consumer_note;
+ Status consumer_status;
+ const string kBufKey = "fake_buf_key";
+ Status status;
+ Device* dst_device = nullptr;
+ string dev_name = "CPU:0";
+ TF_EXPECT_OK(device_mgrs_[0]->LookupDevice(dev_name, &dst_device));
+ DeviceContext* to_device_ctx = nullptr;
+ rma_->RecvFromPeer(
+ "/job:worker/replica:0/task:1/device:" + dev_name, // peer_dev
+ "/job:worker/replica:0/task:1", // peer_task
+ false, // peer_is_local
+ kBufKey, dst_device, to_device_ctx, alloc_attr_, &to_tensor_,
+ device_locality_,
+ [this, &consumer_status, &consumer_note](const Status& s) {
+ consumer_status = s;
+ consumer_note.Notify();
+ });
+ rma_->StartAbort(errors::Internal("Deliberate Failure"));
+ consumer_note.WaitForNotification();
+ EXPECT_EQ(consumer_status.error_message(), "Cancelled");
+}
+
+} // namespace
+} // namespace tensorflow