diff options
author | Derek Murray <mrry@google.com> | 2018-04-19 18:12:57 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-19 18:15:41 -0700 |
commit | b7cca088e90b4c2a28c1038980aa09240584e382 (patch) | |
tree | 8ba76992e2b6f29fe3f5021d12c31afd23971d02 /tensorflow/core/distributed_runtime | |
parent | 7f87125dceb3c69c5fd1d0712c6c93cc4ceaa854 (diff) |
Respect any device filters in {Create,Delete}WorkerSessions().
This is another step towards enabling us to turn on explicit worker
sessions for all master sessions.
PiperOrigin-RevId: 193605565
Diffstat (limited to 'tensorflow/core/distributed_runtime')
5 files changed, 20 insertions, 8 deletions
diff --git a/tensorflow/core/distributed_runtime/master.cc b/tensorflow/core/distributed_runtime/master.cc index f47502e844..288656e7f8 100644 --- a/tensorflow/core/distributed_runtime/master.cc +++ b/tensorflow/core/distributed_runtime/master.cc @@ -417,9 +417,13 @@ void Master::CreateSession(const CreateSessionRequest* req, SessionOptions options; options.config = req->config(); + std::vector<string> filtered_worker_list; + DeviceFinder::GetRemoteWorkers(req->config().device_filters(), env_, + worker_cache, &filtered_worker_list); + MasterSession* session = env_->master_session_factory( options, env_, std::move(remote_devices), std::move(worker_cache_ptr), - std::move(device_set)); + std::move(device_set), std::move(filtered_worker_list)); GraphDef* gdef = const_cast<CreateSessionRequest*>(req)->mutable_graph_def(); diff --git a/tensorflow/core/distributed_runtime/master_env.h b/tensorflow/core/distributed_runtime/master_env.h index 178c5b40ee..16f4d93c8b 100644 --- a/tensorflow/core/distributed_runtime/master_env.h +++ b/tensorflow/core/distributed_runtime/master_env.h @@ -83,7 +83,8 @@ struct MasterEnv { SessionOptions, MasterEnv*, std::unique_ptr<std::vector<std::unique_ptr<Device>>>, std::unique_ptr<WorkerCacheInterface>, - std::unique_ptr<DeviceSet> device_set)> + std::unique_ptr<DeviceSet> device_set, + std::vector<string> filtered_worker_list)> master_session_factory; std::function<Status(const WorkerCacheFactoryOptions&, diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc index 7868200fb4..ebe350d313 100644 --- a/tensorflow/core/distributed_runtime/master_session.cc +++ b/tensorflow/core/distributed_runtime/master_session.cc @@ -416,6 +416,7 @@ Status MasterSession::ReffedClientGraph::DoRegisterPartitions( if (!s.ok()) { for (Part& part : partitions_) { worker_cache_->ReleaseWorker(part.name, part.worker); + part.worker = nullptr; } return s; } @@ -1119,6 +1120,7 @@ MasterSession::MasterSession( std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs, std::unique_ptr<WorkerCacheInterface> worker_cache, std::unique_ptr<DeviceSet> device_set, + std::vector<string> filtered_worker_list, StatsPublisherFactory stats_publisher_factory) : session_opts_(opt), env_(env), @@ -1126,6 +1128,7 @@ MasterSession::MasterSession( remote_devs_(std::move(remote_devs)), worker_cache_(std::move(worker_cache)), devices_(std::move(device_set)), + filtered_worker_list_(std::move(filtered_worker_list)), stats_publisher_factory_(std::move(stats_publisher_factory)), graph_version_(0), run_graphs_(5), @@ -1183,9 +1186,8 @@ Status MasterSession::Create(GraphDef* graph_def, Status MasterSession::CreateWorkerSessions( const WorkerCacheFactoryOptions& options) { - std::vector<string> worker_names; + const std::vector<string> worker_names = filtered_worker_list_; WorkerCacheInterface* worker_cache = get_worker_cache(); - worker_cache->ListWorkers(&worker_names); struct WorkerGroup { // The worker name. (Not owned.) @@ -1263,8 +1265,7 @@ Status MasterSession::CreateWorkerSessions( Status MasterSession::DeleteWorkerSessions() { WorkerCacheInterface* worker_cache = get_worker_cache(); - std::vector<string> worker_names; - worker_cache->ListWorkers(&worker_names); + const std::vector<string>& worker_names = filtered_worker_list_; struct WorkerGroup { // The worker name. (Not owned.) diff --git a/tensorflow/core/distributed_runtime/master_session.h b/tensorflow/core/distributed_runtime/master_session.h index a05419904f..ec34e20b79 100644 --- a/tensorflow/core/distributed_runtime/master_session.h +++ b/tensorflow/core/distributed_runtime/master_session.h @@ -52,6 +52,7 @@ class MasterSession : public core::RefCounted { std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs, std::unique_ptr<WorkerCacheInterface> worker_cache, std::unique_ptr<DeviceSet> device_set, + std::vector<string> filtered_worker_list, StatsPublisherFactory stats_publisher_factory); // Initialize the MasterSession for "def". Must be called before Extend(), @@ -130,6 +131,10 @@ class MasterSession : public core::RefCounted { // The device set used by this session. std::unique_ptr<DeviceSet> devices_; + // The (partial device) names of remote worker tasks that this + // session will contact. + const std::vector<string> filtered_worker_list_; + StatsPublisherFactory stats_publisher_factory_; std::atomic_ulong last_access_time_usec_; @@ -212,7 +217,6 @@ class MasterSession : public core::RefCounted { // workers. Status CreateWorkerSessions(const WorkerCacheFactoryOptions& server_def); - // TODO(b/36574172): Always use Create/DeleteWorkerSession. bool should_delete_worker_sessions_ = false; Status DeleteWorkerSessions(); diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc index be19103582..488dcde9f5 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc @@ -222,10 +222,12 @@ Status GrpcServer::Init( SessionOptions options, const MasterEnv* env, std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs, std::unique_ptr<WorkerCacheInterface> worker_cache, - std::unique_ptr<DeviceSet> device_set) { + std::unique_ptr<DeviceSet> device_set, + std::vector<string> filtered_worker_list) { options.config.MergeFrom(config); return new MasterSession(options, env, std::move(remote_devs), std::move(worker_cache), std::move(device_set), + std::move(filtered_worker_list), stats_factory); }; master_env_.worker_cache_factory = |