aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/distributed_runtime
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2018-04-19 18:12:57 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-19 18:15:41 -0700
commitb7cca088e90b4c2a28c1038980aa09240584e382 (patch)
tree8ba76992e2b6f29fe3f5021d12c31afd23971d02 /tensorflow/core/distributed_runtime
parent7f87125dceb3c69c5fd1d0712c6c93cc4ceaa854 (diff)
Respect any device filters in {Create,Delete}WorkerSessions().
This is another step towards enabling us to turn on explicit worker sessions for all master sessions. PiperOrigin-RevId: 193605565
Diffstat (limited to 'tensorflow/core/distributed_runtime')
-rw-r--r--tensorflow/core/distributed_runtime/master.cc6
-rw-r--r--tensorflow/core/distributed_runtime/master_env.h3
-rw-r--r--tensorflow/core/distributed_runtime/master_session.cc9
-rw-r--r--tensorflow/core/distributed_runtime/master_session.h6
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc4
5 files changed, 20 insertions, 8 deletions
diff --git a/tensorflow/core/distributed_runtime/master.cc b/tensorflow/core/distributed_runtime/master.cc
index f47502e844..288656e7f8 100644
--- a/tensorflow/core/distributed_runtime/master.cc
+++ b/tensorflow/core/distributed_runtime/master.cc
@@ -417,9 +417,13 @@ void Master::CreateSession(const CreateSessionRequest* req,
SessionOptions options;
options.config = req->config();
+ std::vector<string> filtered_worker_list;
+ DeviceFinder::GetRemoteWorkers(req->config().device_filters(), env_,
+ worker_cache, &filtered_worker_list);
+
MasterSession* session = env_->master_session_factory(
options, env_, std::move(remote_devices), std::move(worker_cache_ptr),
- std::move(device_set));
+ std::move(device_set), std::move(filtered_worker_list));
GraphDef* gdef =
const_cast<CreateSessionRequest*>(req)->mutable_graph_def();
diff --git a/tensorflow/core/distributed_runtime/master_env.h b/tensorflow/core/distributed_runtime/master_env.h
index 178c5b40ee..16f4d93c8b 100644
--- a/tensorflow/core/distributed_runtime/master_env.h
+++ b/tensorflow/core/distributed_runtime/master_env.h
@@ -83,7 +83,8 @@ struct MasterEnv {
SessionOptions, MasterEnv*,
std::unique_ptr<std::vector<std::unique_ptr<Device>>>,
std::unique_ptr<WorkerCacheInterface>,
- std::unique_ptr<DeviceSet> device_set)>
+ std::unique_ptr<DeviceSet> device_set,
+ std::vector<string> filtered_worker_list)>
master_session_factory;
std::function<Status(const WorkerCacheFactoryOptions&,
diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc
index 7868200fb4..ebe350d313 100644
--- a/tensorflow/core/distributed_runtime/master_session.cc
+++ b/tensorflow/core/distributed_runtime/master_session.cc
@@ -416,6 +416,7 @@ Status MasterSession::ReffedClientGraph::DoRegisterPartitions(
if (!s.ok()) {
for (Part& part : partitions_) {
worker_cache_->ReleaseWorker(part.name, part.worker);
+ part.worker = nullptr;
}
return s;
}
@@ -1119,6 +1120,7 @@ MasterSession::MasterSession(
std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
std::unique_ptr<WorkerCacheInterface> worker_cache,
std::unique_ptr<DeviceSet> device_set,
+ std::vector<string> filtered_worker_list,
StatsPublisherFactory stats_publisher_factory)
: session_opts_(opt),
env_(env),
@@ -1126,6 +1128,7 @@ MasterSession::MasterSession(
remote_devs_(std::move(remote_devs)),
worker_cache_(std::move(worker_cache)),
devices_(std::move(device_set)),
+ filtered_worker_list_(std::move(filtered_worker_list)),
stats_publisher_factory_(std::move(stats_publisher_factory)),
graph_version_(0),
run_graphs_(5),
@@ -1183,9 +1186,8 @@ Status MasterSession::Create(GraphDef* graph_def,
Status MasterSession::CreateWorkerSessions(
const WorkerCacheFactoryOptions& options) {
- std::vector<string> worker_names;
+ const std::vector<string> worker_names = filtered_worker_list_;
WorkerCacheInterface* worker_cache = get_worker_cache();
- worker_cache->ListWorkers(&worker_names);
struct WorkerGroup {
// The worker name. (Not owned.)
@@ -1263,8 +1265,7 @@ Status MasterSession::CreateWorkerSessions(
Status MasterSession::DeleteWorkerSessions() {
WorkerCacheInterface* worker_cache = get_worker_cache();
- std::vector<string> worker_names;
- worker_cache->ListWorkers(&worker_names);
+ const std::vector<string>& worker_names = filtered_worker_list_;
struct WorkerGroup {
// The worker name. (Not owned.)
diff --git a/tensorflow/core/distributed_runtime/master_session.h b/tensorflow/core/distributed_runtime/master_session.h
index a05419904f..ec34e20b79 100644
--- a/tensorflow/core/distributed_runtime/master_session.h
+++ b/tensorflow/core/distributed_runtime/master_session.h
@@ -52,6 +52,7 @@ class MasterSession : public core::RefCounted {
std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
std::unique_ptr<WorkerCacheInterface> worker_cache,
std::unique_ptr<DeviceSet> device_set,
+ std::vector<string> filtered_worker_list,
StatsPublisherFactory stats_publisher_factory);
// Initialize the MasterSession for "def". Must be called before Extend(),
@@ -130,6 +131,10 @@ class MasterSession : public core::RefCounted {
// The device set used by this session.
std::unique_ptr<DeviceSet> devices_;
+ // The (partial device) names of remote worker tasks that this
+ // session will contact.
+ const std::vector<string> filtered_worker_list_;
+
StatsPublisherFactory stats_publisher_factory_;
std::atomic_ulong last_access_time_usec_;
@@ -212,7 +217,6 @@ class MasterSession : public core::RefCounted {
// workers.
Status CreateWorkerSessions(const WorkerCacheFactoryOptions& server_def);
- // TODO(b/36574172): Always use Create/DeleteWorkerSession.
bool should_delete_worker_sessions_ = false;
Status DeleteWorkerSessions();
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
index be19103582..488dcde9f5 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
@@ -222,10 +222,12 @@ Status GrpcServer::Init(
SessionOptions options, const MasterEnv* env,
std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
std::unique_ptr<WorkerCacheInterface> worker_cache,
- std::unique_ptr<DeviceSet> device_set) {
+ std::unique_ptr<DeviceSet> device_set,
+ std::vector<string> filtered_worker_list) {
options.config.MergeFrom(config);
return new MasterSession(options, env, std::move(remote_devs),
std::move(worker_cache), std::move(device_set),
+ std::move(filtered_worker_list),
stats_factory);
};
master_env_.worker_cache_factory =