author Derek Murray <mrry@google.com> 2018-08-21 08:18:48 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-08-21 08:25:13 -0700
commit 5b456c9ab567c7d9262c57f32693ff33a87946e6 (patch)
tree bb2b74a5b11ff04e9b3319ca017e53b1b48f5c40 /tensorflow/core/distributed_runtime
parent aeab291563b0b4cc75c0f5fc73610a6595780570 (diff)
[Distributed] Add methods to WorkerCache that selectively list workers by job name.
PiperOrigin-RevId: 209597829
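The change adds a ListWorkersInJob method beside the existing ListWorkers on WorkerCacheInterface and GrpcChannelCache (see the hunks below). A self-contained sketch of the call pattern, using a stub cache with the same two-method shape in place of a real gRPC-backed cache; the stub class and its sample worker list are illustrative assumptions, not part of this commit:

#include <iostream>
#include <string>
#include <vector>

// Stub with the same shape as the two listing methods this commit touches;
// a real WorkerCacheInterface implementation would be gRPC-backed.
class StubWorkerCache {
 public:
  void ListWorkers(std::vector<std::string>* workers) const { *workers = all_; }
  void ListWorkersInJob(const std::string& job_name,
                        std::vector<std::string>* workers) const {
    const std::string prefix = "/job:" + job_name + "/";
    for (const std::string& w : all_) {
      if (w.rfind(prefix, 0) == 0) workers->push_back(w);  // prefix match
    }
  }

 private:
  const std::vector<std::string> all_ = {"/job:mnist/replica:0/task:0",
                                         "/job:ps/replica:0/task:0"};
};

int main() {
  StubWorkerCache cache;
  std::vector<std::string> workers;
  cache.ListWorkersInJob("mnist", &workers);
  for (const std::string& w : workers) std::cout << w << "\n";  // mnist only
  return 0;
}

As the grpc_channel.h hunk notes, worker targets follow the /job:<job identifier>/task:<task id> naming, so listing by job amounts to keeping only targets under that job prefix.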
Diffstat (limited to 'tensorflow/core/distributed_runtime')
-rw-r--r-- tensorflow/core/distributed_runtime/BUILD | 1
-rw-r--r-- tensorflow/core/distributed_runtime/master.cc | 51
-rw-r--r-- tensorflow/core/distributed_runtime/rpc/grpc_channel.cc | 14
-rw-r--r-- tensorflow/core/distributed_runtime/rpc/grpc_channel.h | 2
-rw-r--r-- tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc | 65
-rw-r--r-- tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc | 5
-rw-r--r-- tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc | 2
-rw-r--r-- tensorflow/core/distributed_runtime/test_utils.h | 14
-rw-r--r-- tensorflow/core/distributed_runtime/worker_cache.h | 2
-rw-r--r-- tensorflow/core/distributed_runtime/worker_cache_wrapper.h | 4
-rw-r--r-- tensorflow/core/distributed_runtime/worker_session.cc | 5
11 files changed, 147 insertions, 18 deletions
diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD
index b2192c5a80..37029f3f1a 100644
--- a/tensorflow/core/distributed_runtime/BUILD
+++ b/tensorflow/core/distributed_runtime/BUILD
@@ -562,6 +562,7 @@ cc_library(
deps = [
":worker_cache",
":worker_interface",
+ "//tensorflow/core:framework",
],
)
diff --git a/tensorflow/core/distributed_runtime/master.cc b/tensorflow/core/distributed_runtime/master.cc
index a48f734d3e..269f620e42 100644
--- a/tensorflow/core/distributed_runtime/master.cc
+++ b/tensorflow/core/distributed_runtime/master.cc
@@ -53,6 +53,7 @@ limitations under the License.
#include "tensorflow/core/protobuf/master.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h"
#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
@@ -167,13 +168,55 @@ class DeviceFinder {
}
// Enumerates all known workers' target. A target name is a
// prefix of a device name. E.g., /job:mnist/replica:0/task:10.
- CHECK_GT(env_->local_devices.size(), 0) << "No local devices provided.";
- const string& local_device_name = env_->local_devices[0]->name();
- std::vector<string> workers;
- worker_cache->ListWorkers(&workers);
if (filters_.empty()) {
+ // If no filters were specified, we list all known workers in
+ // `worker_cache`.
+ std::vector<string> workers;
+ worker_cache->ListWorkers(&workers);
std::swap(workers, targets_);
} else {
+ // When applying filters, we must include the local worker, even if it
+ // does not match any of the filters.
+ CHECK_GT(env_->local_devices.size(), 0) << "No local devices provided.";
+ const string& local_device_name = env_->local_devices[0]->name();
+ DeviceNameUtils::ParsedName local_parsed_name;
+ CHECK(DeviceNameUtils::ParseFullName(local_device_name,
+ &local_parsed_name));
+ bool all_filters_have_job = true;
+ std::unordered_set<string> filter_job_names({local_parsed_name.job});
+ for (const DeviceNameUtils::ParsedName& filter : filters_) {
+ all_filters_have_job = all_filters_have_job && filter.has_job;
+ if (filter.has_job) {
+ filter_job_names.insert(filter.job);
+ }
+ }
+
+ std::vector<string> workers;
+ if (all_filters_have_job) {
+ // If all of the device filters have a job specified, then we only need
+ // to list the workers in the jobs named in the filter, because a worker
+ // in any other job would not match any filter.
+ for (const string& job_name : filter_job_names) {
+ VLOG(2) << "Selectively listing workers in job: " << job_name;
+ std::vector<string> workers_in_job;
+ worker_cache->ListWorkersInJob(job_name, &workers_in_job);
+ workers.insert(workers.end(), workers_in_job.begin(),
+ workers_in_job.end());
+ }
+ } else {
+ // If any of the device filters does not have a job specified, then we
+ // must list the workers from all jobs.
+ VLOG(2) << "Listing workers in all jobs because some device "
+ << "filter has no job specified. Filters were:";
+ if (device_filters.empty()) {
+ VLOG(2) << "- <NO FILTERS>";
+ } else {
+ for (const string& filter : device_filters) {
+ VLOG(2) << "- " << filter;
+ }
+ }
+ worker_cache->ListWorkers(&workers);
+ }
for (const string& name : workers) {
if (MatchFilters(name) ||
DeviceNameUtils::IsSameAddressSpace(name, local_device_name)) {
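The hunk above only lists workers job by job when every filter names a job; a single job-less filter forces a full ListWorkers call. A standalone sketch of that decision, with plain string parsing standing in for DeviceNameUtils::ParseFullName; the JobFromFilter helper and the sample filter strings are illustrative assumptions:

#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

// Pulls the job name out of a filter like "/job:worker/task:1"; returns
// false when the filter does not begin with a job component.
bool JobFromFilter(const std::string& filter, std::string* job) {
  const std::string kPrefix = "/job:";
  if (filter.compare(0, kPrefix.size(), kPrefix) != 0) return false;
  const size_t end = filter.find('/', kPrefix.size());
  *job = filter.substr(kPrefix.size(), end - kPrefix.size());
  return true;
}

int main() {
  const std::vector<std::string> filters = {"/job:ps", "/job:worker/task:0"};
  // The local worker's job is always kept, mirroring local_parsed_name.job.
  std::unordered_set<std::string> jobs_to_list = {"localhost"};
  bool all_filters_have_job = true;
  for (const std::string& f : filters) {
    std::string job;
    if (JobFromFilter(f, &job)) {
      jobs_to_list.insert(job);
    } else {
      all_filters_have_job = false;  // e.g. a bare filter like "/task:7"
    }
  }
  if (all_filters_have_job) {
    for (const std::string& j : jobs_to_list)
      std::cout << "ListWorkersInJob(" << j << ")\n";
  } else {
    std::cout << "ListWorkers()  // some filter named no job\n";
  }
  return 0;
}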
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc b/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc
index b7eb3c9015..456c30ecf4 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc
@@ -163,6 +163,13 @@ class MultiGrpcChannelCache : public CachingGrpcChannelCache {
}
}
+ void ListWorkersInJob(const string& job_name,
+ std::vector<string>* workers) override {
+ for (GrpcChannelCache* cache : caches_) {
+ cache->ListWorkersInJob(job_name, workers);
+ }
+ }
+
string TranslateTask(const string& target) override {
mutex_lock l(mu_); // could use reader lock
GrpcChannelCache* cache = gtl::FindPtrOrNull(target_caches_, target);
@@ -223,6 +230,13 @@ class SparseGrpcChannelCache : public CachingGrpcChannelCache {
}
}
+ void ListWorkersInJob(const string& job_name,
+ std::vector<string>* workers) override {
+ if (job_name == job_id_) {
+ ListWorkers(workers);
+ }
+ }
+
string TranslateTask(const string& target) override {
DeviceNameUtils::ParsedName parsed;
if (!DeviceNameUtils::ParseFullName(target, &parsed)) {
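The two implementations above divide the work: SparseGrpcChannelCache answers ListWorkersInJob only when the query matches its own job_id_, while MultiGrpcChannelCache fans the query out to every sub-cache. A toy model of that composition; the JobCache struct and sample targets are illustrative assumptions:

#include <iostream>
#include <string>
#include <vector>

// One job's cache; answers only for a matching job name, mirroring
// SparseGrpcChannelCache::ListWorkersInJob above.
struct JobCache {
  std::string job_id;
  std::vector<std::string> workers;
  void ListWorkersInJob(const std::string& job_name,
                        std::vector<std::string>* out) const {
    if (job_name == job_id)
      out->insert(out->end(), workers.begin(), workers.end());
  }
};

int main() {
  // Two per-job caches behind one query loop, as MultiGrpcChannelCache does.
  const std::vector<JobCache> caches = {
      {"mnist", {"/job:mnist/replica:0/task:0", "/job:mnist/replica:0/task:1"}},
      {"ps", {"/job:ps/replica:0/task:0"}}};
  std::vector<std::string> workers;
  for (const JobCache& c : caches) c.ListWorkersInJob("mnist", &workers);
  for (const std::string& w : workers) std::cout << w << "\n";  // mnist tasks
  return 0;
}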
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_channel.h b/tensorflow/core/distributed_runtime/rpc/grpc_channel.h
index 4861cdb691..6fa99d7b14 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_channel.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_channel.h
@@ -66,6 +66,8 @@ class GrpcChannelCache {
// /job:<job identifier>/task:<task id>
// e.g. /job:mnist/task:2
virtual void ListWorkers(std::vector<string>* workers) = 0;
+ virtual void ListWorkersInJob(const string& job_name,
+ std::vector<string>* workers) = 0;
// If found, returns a gRPC channel that is connected to the remote
// worker named by 'target'. 'target' is of the following
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc b/tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc
index f07a5a0974..a814ef85e2 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc
@@ -89,13 +89,33 @@ TEST(GrpcChannelTest, HostPorts) {
EXPECT_NE(d_4_1.get(), e_5_2.get());
}
- std::vector<string> workers;
- cc->ListWorkers(&workers);
- EXPECT_EQ(std::vector<string>(
- {"/job:mnist/replica:0/task:0", "/job:mnist/replica:0/task:1",
- "/job:mnist/replica:0/task:2", "/job:mnist/replica:0/task:3",
- "/job:mnist/replica:0/task:4", "/job:mnist/replica:0/task:5"}),
- workers);
+ {
+ std::vector<string> workers;
+ cc->ListWorkers(&workers);
+ EXPECT_EQ(
+ std::vector<string>(
+ {"/job:mnist/replica:0/task:0", "/job:mnist/replica:0/task:1",
+ "/job:mnist/replica:0/task:2", "/job:mnist/replica:0/task:3",
+ "/job:mnist/replica:0/task:4", "/job:mnist/replica:0/task:5"}),
+ workers);
+ }
+
+ {
+ std::vector<string> workers;
+ cc->ListWorkersInJob("mnist", &workers);
+ EXPECT_EQ(
+ std::vector<string>(
+ {"/job:mnist/replica:0/task:0", "/job:mnist/replica:0/task:1",
+ "/job:mnist/replica:0/task:2", "/job:mnist/replica:0/task:3",
+ "/job:mnist/replica:0/task:4", "/job:mnist/replica:0/task:5"}),
+ workers);
+ }
+
+ {
+ std::vector<string> workers;
+ cc->ListWorkersInJob("other", &workers);
+ EXPECT_TRUE(workers.empty());
+ }
}
TEST(GrpcChannelTest, SparseHostPorts) {
@@ -135,13 +155,30 @@ TEST(GrpcChannelTest, SparseHostPorts) {
EXPECT_NE(d_4_1.get(), e_5_2.get());
}
- std::vector<string> workers;
- cc->ListWorkers(&workers);
- std::sort(workers.begin(), workers.end());
- EXPECT_EQ(std::vector<string>({"/job:mnist/replica:0/task:0",
- "/job:mnist/replica:0/task:3",
- "/job:mnist/replica:0/task:4"}),
- workers);
+ {
+ std::vector<string> workers;
+ cc->ListWorkers(&workers);
+ std::sort(workers.begin(), workers.end());
+ EXPECT_EQ(std::vector<string>({"/job:mnist/replica:0/task:0",
+ "/job:mnist/replica:0/task:3",
+ "/job:mnist/replica:0/task:4"}),
+ workers);
+ }
+
+ {
+ std::vector<string> workers;
+ cc->ListWorkersInJob("mnist", &workers);
+ EXPECT_EQ(std::vector<string>({"/job:mnist/replica:0/task:0",
+ "/job:mnist/replica:0/task:3",
+ "/job:mnist/replica:0/task:4"}),
+ workers);
+ }
+
+ {
+ std::vector<string> workers;
+ cc->ListWorkersInJob("other", &workers);
+ EXPECT_TRUE(workers.empty());
+ }
}
TEST(GrpcChannelTest, NewHostPortGrpcChannelValidation) {
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc
index b9f21ea211..e1541db69b 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc
@@ -54,6 +54,11 @@ class GrpcWorkerCache : public WorkerCachePartial {
channel_cache_->ListWorkers(workers);
}
+ void ListWorkersInJob(const string& job_name,
+ std::vector<string>* workers) const override {
+ channel_cache_->ListWorkersInJob(job_name, workers);
+ }
+
WorkerInterface* CreateWorker(const string& target) override {
if (target == local_target_) {
return local_worker_;
diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc
index 25ff6512a0..b070dd13dd 100644
--- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc
+++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc
@@ -50,6 +50,8 @@ namespace {
// Fake cache implementation for WorkerEnv.
class DummyWorkerCache : public WorkerCacheInterface {
void ListWorkers(std::vector<string>* workers) const override {}
+ void ListWorkersInJob(const string& job_name,
+ std::vector<string>* workers) const override {}
WorkerInterface* CreateWorker(const string& target) override {
return nullptr;
}
diff --git a/tensorflow/core/distributed_runtime/test_utils.h b/tensorflow/core/distributed_runtime/test_utils.h
index 48d83845dd..88a97da34d 100644
--- a/tensorflow/core/distributed_runtime/test_utils.h
+++ b/tensorflow/core/distributed_runtime/test_utils.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <unordered_map>
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
+#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
@@ -138,6 +139,19 @@ class TestWorkerCache : public WorkerCacheInterface {
}
}
+ void ListWorkersInJob(const string& job_name,
+ std::vector<string>* workers) const override {
+ workers->clear();
+ for (auto it : workers_) {
+ DeviceNameUtils::ParsedName device_name;
+ CHECK(DeviceNameUtils::ParseFullName(it.first, &device_name));
+ CHECK(device_name.has_job);
+ if (job_name == device_name.job) {
+ workers->push_back(it.first);
+ }
+ }
+ }
+
WorkerInterface* CreateWorker(const string& target) override {
auto it = workers_.find(target);
if (it != workers_.end()) {
diff --git a/tensorflow/core/distributed_runtime/worker_cache.h b/tensorflow/core/distributed_runtime/worker_cache.h
index 8521f8956b..0c8575b4d5 100644
--- a/tensorflow/core/distributed_runtime/worker_cache.h
+++ b/tensorflow/core/distributed_runtime/worker_cache.h
@@ -36,6 +36,8 @@ class WorkerCacheInterface {
// Updates *workers with strings naming the remote worker tasks to
// which open channels have been established.
virtual void ListWorkers(std::vector<string>* workers) const = 0;
+ virtual void ListWorkersInJob(const string& job_name,
+ std::vector<string>* workers) const = 0;
// If "target" names a remote task for which an RPC channel exists
// or can be constructed, returns a pointer to a WorkerInterface object
diff --git a/tensorflow/core/distributed_runtime/worker_cache_wrapper.h b/tensorflow/core/distributed_runtime/worker_cache_wrapper.h
index 43c3b6285b..1f309b4361 100644
--- a/tensorflow/core/distributed_runtime/worker_cache_wrapper.h
+++ b/tensorflow/core/distributed_runtime/worker_cache_wrapper.h
@@ -32,6 +32,10 @@ class WorkerCacheWrapper : public WorkerCacheInterface {
virtual void ListWorkers(std::vector<string>* workers) const {
return wrapped_->ListWorkers(workers);
}
+ virtual void ListWorkersInJob(const string& job_name,
+ std::vector<string>* workers) const {
+ return wrapped_->ListWorkersInJob(job_name, workers);
+ }
// If "target" names a remote task for which an RPC channel exists
// or can be constructed, returns a pointer to a WorkerInterface object
diff --git a/tensorflow/core/distributed_runtime/worker_session.cc b/tensorflow/core/distributed_runtime/worker_session.cc
index ca6dc1b1de..c7d0c6b7f3 100644
--- a/tensorflow/core/distributed_runtime/worker_session.cc
+++ b/tensorflow/core/distributed_runtime/worker_session.cc
@@ -35,6 +35,11 @@ class WorkerFreeListCache : public WorkerCacheInterface {
wrapped_->ListWorkers(workers);
}
+ void ListWorkersInJob(const string& job_name,
+ std::vector<string>* workers) const override {
+ wrapped_->ListWorkersInJob(job_name, workers);
+ }
+
WorkerInterface* CreateWorker(const string& target) override {
mutex_lock l(mu_);
auto p = workers_.find(target);