diff options
| | | |
|---|---|---|
| author | Derek Murray <mrry@google.com> | 2018-08-21 08:18:48 -0700 |
| committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-21 08:25:13 -0700 |
| commit | 5b456c9ab567c7d9262c57f32693ff33a87946e6 (patch) | |
| tree | bb2b74a5b11ff04e9b3319ca017e53b1b48f5c40 /tensorflow/core/distributed_runtime/rpc | |
| parent | aeab291563b0b4cc75c0f5fc73610a6595780570 (diff) | |
[Distributed] Add methods to WorkerCache that selectively list workers by job name.
PiperOrigin-RevId: 209597829
Diffstat (limited to 'tensorflow/core/distributed_runtime/rpc')
5 files changed, 74 insertions, 14 deletions
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc b/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc index b7eb3c9015..456c30ecf4 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc @@ -163,6 +163,13 @@ class MultiGrpcChannelCache : public CachingGrpcChannelCache { } } + void ListWorkersInJob(const string& job_name, + std::vector<string>* workers) override { + for (GrpcChannelCache* cache : caches_) { + cache->ListWorkersInJob(job_name, workers); + } + } + string TranslateTask(const string& target) override { mutex_lock l(mu_); // could use reader lock GrpcChannelCache* cache = gtl::FindPtrOrNull(target_caches_, target); @@ -223,6 +230,13 @@ class SparseGrpcChannelCache : public CachingGrpcChannelCache { } } + void ListWorkersInJob(const string& job_name, + std::vector<string>* workers) override { + if (job_name == job_id_) { + ListWorkers(workers); + } + } + string TranslateTask(const string& target) override { DeviceNameUtils::ParsedName parsed; if (!DeviceNameUtils::ParseFullName(target, &parsed)) { diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_channel.h b/tensorflow/core/distributed_runtime/rpc/grpc_channel.h index 4861cdb691..6fa99d7b14 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_channel.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_channel.h @@ -66,6 +66,8 @@ class GrpcChannelCache { // /job:<job identifier>/task:<task id> // e.g. /job:mnist/task:2 virtual void ListWorkers(std::vector<string>* workers) = 0; + virtual void ListWorkersInJob(const string& job_name, + std::vector<string>* workers) = 0; // If found, returns a gRPC channel that is connected to the remote // worker named by 'target'. 
'target' is of the following diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc b/tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc index f07a5a0974..a814ef85e2 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc @@ -89,13 +89,33 @@ TEST(GrpcChannelTest, HostPorts) { EXPECT_NE(d_4_1.get(), e_5_2.get()); } - std::vector<string> workers; - cc->ListWorkers(&workers); - EXPECT_EQ(std::vector<string>( - {"/job:mnist/replica:0/task:0", "/job:mnist/replica:0/task:1", - "/job:mnist/replica:0/task:2", "/job:mnist/replica:0/task:3", - "/job:mnist/replica:0/task:4", "/job:mnist/replica:0/task:5"}), - workers); + { + std::vector<string> workers; + cc->ListWorkers(&workers); + EXPECT_EQ( + std::vector<string>( + {"/job:mnist/replica:0/task:0", "/job:mnist/replica:0/task:1", + "/job:mnist/replica:0/task:2", "/job:mnist/replica:0/task:3", + "/job:mnist/replica:0/task:4", "/job:mnist/replica:0/task:5"}), + workers); + } + + { + std::vector<string> workers; + cc->ListWorkersInJob("mnist", &workers); + EXPECT_EQ( + std::vector<string>( + {"/job:mnist/replica:0/task:0", "/job:mnist/replica:0/task:1", + "/job:mnist/replica:0/task:2", "/job:mnist/replica:0/task:3", + "/job:mnist/replica:0/task:4", "/job:mnist/replica:0/task:5"}), + workers); + } + + { + std::vector<string> workers; + cc->ListWorkersInJob("other", &workers); + EXPECT_TRUE(workers.empty()); + } } TEST(GrpcChannelTest, SparseHostPorts) { @@ -135,13 +155,30 @@ TEST(GrpcChannelTest, SparseHostPorts) { EXPECT_NE(d_4_1.get(), e_5_2.get()); } - std::vector<string> workers; - cc->ListWorkers(&workers); - std::sort(workers.begin(), workers.end()); - EXPECT_EQ(std::vector<string>({"/job:mnist/replica:0/task:0", - "/job:mnist/replica:0/task:3", - "/job:mnist/replica:0/task:4"}), - workers); + { + std::vector<string> workers; + cc->ListWorkers(&workers); + std::sort(workers.begin(), workers.end()); + 
EXPECT_EQ(std::vector<string>({"/job:mnist/replica:0/task:0", + "/job:mnist/replica:0/task:3", + "/job:mnist/replica:0/task:4"}), + workers); + } + + { + std::vector<string> workers; + cc->ListWorkersInJob("mnist", &workers); + EXPECT_EQ(std::vector<string>({"/job:mnist/replica:0/task:0", + "/job:mnist/replica:0/task:3", + "/job:mnist/replica:0/task:4"}), + workers); + } + + { + std::vector<string> workers; + cc->ListWorkersInJob("other", &workers); + EXPECT_TRUE(workers.empty()); + } } TEST(GrpcChannelTest, NewHostPortGrpcChannelValidation) { diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc index b9f21ea211..e1541db69b 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc @@ -54,6 +54,11 @@ class GrpcWorkerCache : public WorkerCachePartial { channel_cache_->ListWorkers(workers); } + void ListWorkersInJob(const string& job_name, + std::vector<string>* workers) const override { + channel_cache_->ListWorkersInJob(job_name, workers); + } + WorkerInterface* CreateWorker(const string& target) override { if (target == local_target_) { return local_worker_; diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc index 25ff6512a0..b070dd13dd 100644 --- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc +++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc @@ -50,6 +50,8 @@ namespace { // Fake cache implementation for WorkerEnv. class DummyWorkerCache : public WorkerCacheInterface { void ListWorkers(std::vector<string>* workers) const override {} + void ListWorkersInJob(const string& job_name, + std::vector<string>* workers) const override {} WorkerInterface* CreateWorker(const string& target) override { return nullptr; } |