// path: root/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc
// blob: b070dd13dd6f18d27ef5498a00c1f43f225b95c9 (plain)
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"

#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/framework/control_flow.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {

// Wraps `content` in a scalar (rank-0) DT_STRING tensor.
Tensor V(const string& content) {
  Tensor result(DT_STRING, TensorShape({}));
  auto element = result.scalar<string>();
  element() = content;
  return result;
}

// Returns the string stored in a scalar DT_STRING tensor. CHECK-fails when
// `tensor` has any other dtype or is not a scalar.
string V(const Tensor& tensor) {
  CHECK_EQ(tensor.dtype(), DT_STRING);
  CHECK(TensorShapeUtils::IsScalar(tensor.shape()));
  const auto element = tensor.scalar<string>();
  return element();
}

// Parses the rendezvous key string `s` into a ParsedKey. CHECK-fails if
// Rendezvous::ParseKey() does not return an OK status.
Rendezvous::ParsedKey MakeKey(const string& s) {
  Rendezvous::ParsedKey parsed;
  const auto parse_status = Rendezvous::ParseKey(s, &parsed);
  CHECK(parse_status.ok());
  return parsed;
}

namespace {
// Stub WorkerCacheInterface used to populate the WorkerSession in these tests.
// Every method ignores its arguments: the listing methods leave the output
// vector untouched, CreateWorker() returns nullptr, and
// GetDeviceLocalityNonBlocking() always returns false.
class DummyWorkerCache : public WorkerCacheInterface {
  void ListWorkers(std::vector<string>* /*workers*/) const override {}

  void ListWorkersInJob(const string& /*job_name*/,
                        std::vector<string>* /*workers*/) const override {}

  WorkerInterface* CreateWorker(const string& /*target*/) override {
    return nullptr;
  }

  bool GetDeviceLocalityNonBlocking(const string& /*device*/,
                                    DeviceLocality* /*locality*/) override {
    return false;
  }

  // Note: the `done` callback is never invoked by this stub.
  void GetDeviceLocalityAsync(const string& /*device*/,
                              DeviceLocality* /*locality*/,
                              StatusCallback /*done*/) override {}
};
}  // namespace

// Test fixture holding the RpcRendezvousMgr under test together with the
// WorkerEnv it is constructed from and the WorkerSession that the tests pass
// to RemoteRendezvous::Initialize().
class RpcRendezvousMgrTest : public ::testing::Test {
 protected:
  // Members are initialized in declaration order (cache_, env,
  // worker_session_, rmgr_), regardless of the order written in the
  // initializer list. `env` is therefore already constructed when its address
  // is handed to rmgr_; env.env itself is only assigned afterwards, in the
  // constructor body.
  RpcRendezvousMgrTest()
      : cache_(new DummyWorkerCache),
        worker_session_("rpc_session", "/job:mnist/replica:1/task:2",
                        std::unique_ptr<WorkerCacheInterface>(cache_),
                        std::unique_ptr<DeviceMgr>(),
                        std::unique_ptr<GraphMgr>()),
        rmgr_(&env) {
    env.env = Env::Default();
  }

  // Raw alias of the cache; ownership is transferred to worker_session_ via
  // the unique_ptr built in the constructor, so it must not be deleted here.
  DummyWorkerCache* cache_;  // Managed by worker_session.
  // Must stay declared before rmgr_, which is constructed from &env.
  WorkerEnv env;

  // Created with null DeviceMgr and GraphMgr pointers; only the worker cache
  // is populated.
  WorkerSession worker_session_;
  // The object under test.
  RpcRendezvousMgr rmgr_;
};

// Sends a string tensor through the rendezvous for one step and reads it back
// with RpcRendezvousMgr::RecvLocal(). Both device names in the key use the
// same "/job:mnist/replica:1/task:2" prefix that the fixture gives to its
// WorkerSession.
TEST_F(RpcRendezvousMgrTest, LocalSendRecv) {
  const int64 step_id = 123;
  const Rendezvous::ParsedKey key = MakeKey(Rendezvous::CreateKey(
      "/job:mnist/replica:1/task:2/cpu:0", 7890,
      "/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)));
  {
    RemoteRendezvous* rendez = rmgr_.Find(step_id);
    // Install the unref guard before the first TF_ASSERT_OK: a failing
    // assertion returns from the test body, and the reference obtained from
    // Find() must still be released on that path.
    core::ScopedUnref unref(rendez);
    TF_ASSERT_OK(rendez->Initialize(&worker_session_));
    Rendezvous::Args args;
    TF_ASSERT_OK(rendez->Send(key, args, V("peach"), false));
  }
  {
    // Starts out as DT_FLOAT; V(val) below CHECK-fails on anything but a
    // DT_STRING scalar, so the test only passes if RecvLocal() replaced this
    // placeholder with the tensor that was sent.
    Tensor val(DT_FLOAT);
    bool val_dead = false;
    TF_ASSERT_OK(rmgr_.RecvLocal(step_id, key, &val, &val_dead));
    EXPECT_EQ(V(val), "peach");
  }
  rmgr_.Cleanup(step_id);
}

// Checks that Recv() on a key nobody sends to returns an Aborted status once
// the rendezvous is aborted from a closure scheduled with SchedClosure():
// first through an explicit StartAbort(), then through
// RpcRendezvousMgr::Cleanup().
TEST_F(RpcRendezvousMgrTest, LocalAbort) {
  const Rendezvous::ParsedKey key = MakeKey(Rendezvous::CreateKey(
      "/job:mnist/replica:1/task:2/cpu:0", 7890,
      "/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)));
  {  // Explicit Abort().
    const int64 step_id = 123;
    RemoteRendezvous* rendez = rmgr_.Find(step_id);
    core::ScopedUnref unref(rendez);
    // Initialize before scheduling anything, so a failing assertion cannot
    // return from the test while a closure referring to this frame is still
    // pending.
    TF_ASSERT_OK(rendez->Initialize(&worker_session_));
    Notification abort_closure_done;
    SchedClosure([this, rendez, &abort_closure_done]() {
      env.env->SleepForMicroseconds(100 * 1000);
      rendez->StartAbort(errors::Aborted(""));
      abort_closure_done.Notify();
    });
    Tensor val(DT_STRING);
    bool val_dead = false;
    Rendezvous::Args args;
    EXPECT_TRUE(errors::IsAborted(rendez->Recv(key, args, &val, &val_dead)));
    // The closure uses `this` and `rendez`; do not leave the scope (and drop
    // our reference) until it has completely finished running.
    abort_closure_done.WaitForNotification();
  }
  {  // Cleanup causes Abort().
    const int64 step_id = 321;
    RemoteRendezvous* rendez = rmgr_.Find(step_id);
    core::ScopedUnref unref(rendez);
    TF_ASSERT_OK(rendez->Initialize(&worker_session_));
    Notification cleanup_closure_done;
    SchedClosure([this, step_id, &cleanup_closure_done]() {
      env.env->SleepForMicroseconds(100 * 1000);
      rmgr_.Cleanup(step_id);
      cleanup_closure_done.Notify();
    });
    Tensor val(DT_STRING);
    bool val_dead = false;
    Rendezvous::Args args;
    EXPECT_TRUE(errors::IsAborted(rendez->Recv(key, args, &val, &val_dead)));
    // The closure calls into rmgr_, a member of this fixture; wait for it to
    // return so the fixture cannot be destroyed underneath it.
    cleanup_closure_done.WaitForNotification();
  }
}

// Checks that after RpcRendezvousMgr::CleanupAll() a Recv() on a rendezvous
// obtained earlier returns an Aborted status, even though a value had already
// been sent for the key.
TEST_F(RpcRendezvousMgrTest, CleanupAll) {
  const Rendezvous::ParsedKey key = MakeKey(Rendezvous::CreateKey(
      "/job:mnist/replica:1/task:2/cpu:0", 7890,
      "/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)));
  {
    const int64 step_id = 123;
    RemoteRendezvous* rendez = rmgr_.Find(step_id);
    // Install the unref guard before the first TF_ASSERT_OK so the reference
    // obtained from Find() is released even if an assertion fails and returns
    // from the test body.
    core::ScopedUnref unref(rendez);
    TF_ASSERT_OK(rendez->Initialize(&worker_session_));
    Rendezvous::Args args;
    TF_ASSERT_OK(rendez->Send(key, args, V("peach"), false));
    rmgr_.CleanupAll();
    Tensor val(DT_STRING);
    bool val_dead = false;
    EXPECT_TRUE(errors::IsAborted(rendez->Recv(key, args, &val, &val_dead)));
  }
}

// Minimal DeviceContext that carries nothing but an integer stream id, so a
// test can recognise the context instance it attached to a Send().
class DummyDeviceContext : public DeviceContext {
 public:
  explicit DummyDeviceContext(int stream_id) : stream_id_(stream_id) {}
  ~DummyDeviceContext() override = default;

  // Returns the id supplied at construction.
  int stream_id() const { return stream_id_; }

 private:
  const int stream_id_;
};

// Attaches a DummyDeviceContext to the send-side Args and checks that the
// callback of RpcRendezvousMgr::RecvLocalAsync() sees the same stream id in
// the send Args it is given, along with the tensor that was sent.
TEST_F(RpcRendezvousMgrTest, TransferDummyDeviceContext) {
  DummyDeviceContext* dc = new DummyDeviceContext(123);
  // Releases this test's reference to `dc` on every exit path, including the
  // early return taken when a TF_ASSERT_OK below fails. Being the first local
  // declared, it is released last, i.e. after rmgr_.Cleanup() at the end.
  core::ScopedUnref dc_unref(dc);

  const int64 step_id = 123;
  const Rendezvous::ParsedKey key = MakeKey(Rendezvous::CreateKey(
      "/job:mnist/replica:1/task:2/cpu:0", 7890,
      "/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)));
  {
    RemoteRendezvous* rendez = rmgr_.Find(step_id);
    core::ScopedUnref unref(rendez);
    Rendezvous::Args args;
    args.device_context = dc;
    TF_ASSERT_OK(rendez->Initialize(&worker_session_));
    TF_ASSERT_OK(rendez->Send(key, args, V("peach"), false));
  }
  {
    Notification n;
    rmgr_.RecvLocalAsync(
        step_id, key,
        [&n](const Status& s, const Rendezvous::Args& send_args,
             const Rendezvous::Args& recv_args, const Tensor& val,
             bool is_dead) {
          // Fail with the status itself rather than with a crash on the
          // pointer dereference below if the receive did not succeed.
          TF_CHECK_OK(s);
          auto send_dev_context =
              static_cast<DummyDeviceContext*>(send_args.device_context);
          CHECK(send_dev_context != nullptr);
          CHECK_EQ(123, send_dev_context->stream_id());
          CHECK_EQ(V(val), "peach");
          n.Notify();
        });
    // `n` is captured by reference, so it must outlive the callback.
    n.WaitForNotification();
  }
  rmgr_.Cleanup(step_id);
}

// NOTE: Remote Send/Recv is better tested in worker_test.cc

}  // namespace tensorflow