diff options
12 files changed, 153 insertions, 41 deletions
diff --git a/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc b/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc index 000a03da5d..6edc2ec5ed 100644 --- a/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc +++ b/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc @@ -145,6 +145,7 @@ Status ClusterFunctionLibraryRuntime::Instantiate( RegisterGraphRequest req; req.set_session_handle(worker_session_->session_name); + req.set_create_worker_session_called(create_worker_session_called_); *req.mutable_graph_def() = gdef; req.mutable_graph_options() ->mutable_optimizer_options() @@ -182,6 +183,7 @@ void ClusterFunctionLibraryRuntime::Run( RunGraphRequest* req = new RunGraphRequest; req->set_session_handle(worker_session_->session_name); + req->set_create_worker_session_called(create_worker_session_called_); req->set_graph_handle(function_data->graph_handle); // Borrowed from master_session.cc const uint64 step_id = (random::New64() & ((1uLL << 56) - 1)) | (1uLL << 56); diff --git a/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h b/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h index d3ca350e36..1ea0a3ad51 100644 --- a/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h +++ b/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h @@ -27,8 +27,10 @@ struct WorkerSession; // functions across processes by making RPCs. class ClusterFunctionLibraryRuntime : public DistributedFunctionLibraryRuntime { public: - ClusterFunctionLibraryRuntime(WorkerSession* worker_session) - : worker_session_(worker_session) {} + ClusterFunctionLibraryRuntime(WorkerSession* worker_session, + bool create_worker_session_called) + : worker_session_(worker_session), + create_worker_session_called_(create_worker_session_called) {} ~ClusterFunctionLibraryRuntime() override; @@ -51,6 +53,7 @@ class ClusterFunctionLibraryRuntime : public DistributedFunctionLibraryRuntime { mutable mutex mu_; WorkerSession* const worker_session_ = nullptr; // not owned. + const bool create_worker_session_called_; struct FunctionData { const string graph_handle; diff --git a/tensorflow/core/distributed_runtime/cluster_function_library_runtime_test.cc b/tensorflow/core/distributed_runtime/cluster_function_library_runtime_test.cc index 1810996ab8..6f96d7cb06 100644 --- a/tensorflow/core/distributed_runtime/cluster_function_library_runtime_test.cc +++ b/tensorflow/core/distributed_runtime/cluster_function_library_runtime_test.cc @@ -44,7 +44,7 @@ class ClusterFunctionLibraryRuntimeTest : public ::testing::Test { std::unique_ptr<GraphMgr>())); cluster_flr_.reset( - new ClusterFunctionLibraryRuntime(worker_session_.get())); + new ClusterFunctionLibraryRuntime(worker_session_.get(), true)); } Status ConstructFunctionGraphHelper( diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc index e0a5bb4c53..08020f0266 100644 --- a/tensorflow/core/distributed_runtime/master_session.cc +++ b/tensorflow/core/distributed_runtime/master_session.cc @@ -431,6 +431,7 @@ Status MasterSession::ReffedClientGraph::DoRegisterPartitions( const Part& part = partitions_[i]; Call* c = &calls[i]; c->req.set_session_handle(session_handle_); + c->req.set_create_worker_session_called(!should_deregister_); c->req.mutable_graph_def()->Swap(&graph_partitions[part.name]); *c->req.mutable_graph_options() = session_opts_.config.graph_options(); *c->req.mutable_debug_options() = @@ -587,6 +588,7 @@ Status MasterSession::ReffedClientGraph::RunPartitionsHelper( c->req->set_is_last_partial_run(is_last_partial_run); } c->req->set_session_handle(session_handle_); + c->req->set_create_worker_session_called(!should_deregister_); c->req->set_graph_handle(part.graph_handle); c->req->set_step_id(step_id); *c->req->mutable_exec_opts() = exec_opts; @@ -1003,6 +1005,7 @@ void MasterSession::ReffedClientGraph::DeregisterPartitions() { if (!part.graph_handle.empty()) { Call* c = new Call; c->req.set_session_handle(session_handle_); + c->req.set_create_worker_session_called(!should_deregister_); c->req.set_graph_handle(part.graph_handle); // NOTE(mrry): We must capture `worker_cache_` since `this` // could be deleted before the callback is called. diff --git a/tensorflow/core/distributed_runtime/message_wrappers.cc b/tensorflow/core/distributed_runtime/message_wrappers.cc index 18668b44d3..40bf564cab 100644 --- a/tensorflow/core/distributed_runtime/message_wrappers.cc +++ b/tensorflow/core/distributed_runtime/message_wrappers.cc @@ -282,10 +282,18 @@ const string& InMemoryRunGraphRequest::session_handle() const { return session_handle_; } +bool InMemoryRunGraphRequest::create_worker_session_called() const { + return create_worker_session_called_; +} + void InMemoryRunGraphRequest::set_session_handle(const string& handle) { session_handle_ = handle; } +void InMemoryRunGraphRequest::set_create_worker_session_called(bool called) { + create_worker_session_called_ = called; +} + const string& InMemoryRunGraphRequest::graph_handle() const { return graph_handle_; } @@ -378,6 +386,8 @@ const RunGraphRequest& InMemoryRunGraphRequest::ToProto() const { if (!proto_version_) { proto_version_.reset(new RunGraphRequest); proto_version_->set_session_handle(session_handle()); + proto_version_->set_create_worker_session_called( + create_worker_session_called()); proto_version_->set_graph_handle(graph_handle()); proto_version_->set_step_id(step_id()); *proto_version_->mutable_exec_opts() = exec_opts(); @@ -403,6 +413,15 @@ void MutableProtoRunGraphRequest::set_session_handle(const string& handle) { request_.set_session_handle(handle); } +bool MutableProtoRunGraphRequest::create_worker_session_called() const { + return request_.create_worker_session_called(); +} + +void MutableProtoRunGraphRequest::set_create_worker_session_called( + bool called) { + request_.set_create_worker_session_called(called); +} + const string& MutableProtoRunGraphRequest::graph_handle() const { return request_.graph_handle(); } @@ -514,6 +533,10 @@ const string& ProtoRunGraphRequest::session_handle() const { return request_->session_handle(); } +bool ProtoRunGraphRequest::create_worker_session_called() const { + return request_->create_worker_session_called(); +} + const string& ProtoRunGraphRequest::graph_handle() const { return request_->graph_handle(); } diff --git a/tensorflow/core/distributed_runtime/message_wrappers.h b/tensorflow/core/distributed_runtime/message_wrappers.h index 1f7cdb98a4..92c5668e3a 100644 --- a/tensorflow/core/distributed_runtime/message_wrappers.h +++ b/tensorflow/core/distributed_runtime/message_wrappers.h @@ -246,6 +246,9 @@ class RunGraphRequestWrapper { // namespace is used. virtual const string& session_handle() const = 0; + // Set to true if `CreateWorkerSession` was called for `session_handle`. + virtual bool create_worker_session_called() const = 0; + // REQUIRED: graph_handle must be returned by a RegisterGraph call // to the same WorkerService. virtual const string& graph_handle() const = 0; @@ -293,6 +296,7 @@ class RunGraphRequestWrapper { class MutableRunGraphRequestWrapper : public RunGraphRequestWrapper { public: virtual void set_session_handle(const string& handle) = 0; + virtual void set_create_worker_session_called(bool called) = 0; virtual void set_graph_handle(const string& handle) = 0; virtual void set_step_id(int64 step_id) = 0; virtual ExecutorOpts* mutable_exec_opts() = 0; @@ -317,6 +321,7 @@ class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper { // RunGraphRequestWrapper methods. const string& session_handle() const override; const string& graph_handle() const override; + bool create_worker_session_called() const override; int64 step_id() const override; const ExecutorOpts& exec_opts() const override; size_t num_sends() const override; @@ -331,6 +336,7 @@ class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper { // MutableRunGraphRequestWrapper methods. void set_session_handle(const string& handle) override; + void set_create_worker_session_called(bool called) override; void set_graph_handle(const string& handle) override; void set_step_id(int64 step_id) override; ExecutorOpts* mutable_exec_opts() override; @@ -347,6 +353,7 @@ class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper { private: string session_handle_; + bool create_worker_session_called_; string graph_handle_; int64 step_id_; ExecutorOpts exec_opts_; @@ -370,6 +377,7 @@ class MutableProtoRunGraphRequest : public MutableRunGraphRequestWrapper { public: // RunGraphRequestWrapper methods. const string& session_handle() const override; + bool create_worker_session_called() const override; const string& graph_handle() const override; int64 step_id() const override; const ExecutorOpts& exec_opts() const override; @@ -385,6 +393,7 @@ class MutableProtoRunGraphRequest : public MutableRunGraphRequestWrapper { // MutableRunGraphRequestWrapper methods. void set_session_handle(const string& handle) override; + void set_create_worker_session_called(bool called) override; void set_graph_handle(const string& handle) override; void set_step_id(int64 step_id) override; ExecutorOpts* mutable_exec_opts() override; @@ -409,6 +418,7 @@ class ProtoRunGraphRequest : public RunGraphRequestWrapper { // RunGraphRequestWrapper methods. const string& session_handle() const override; + bool create_worker_session_called() const override; const string& graph_handle() const override; int64 step_id() const override; const ExecutorOpts& exec_opts() const override; diff --git a/tensorflow/core/distributed_runtime/session_mgr.cc b/tensorflow/core/distributed_runtime/session_mgr.cc index 51b9547f53..e51d63cf2b 100644 --- a/tensorflow/core/distributed_runtime/session_mgr.cc +++ b/tensorflow/core/distributed_runtime/session_mgr.cc @@ -98,20 +98,26 @@ Status SessionMgr::DeleteSession(const string& session) { return Status::OK(); } -std::shared_ptr<WorkerSession> SessionMgr::WorkerSessionForSessionUnlocked( - const string& session) { - auto it = sessions_.find(session); - if (it == sessions_.end()) { - return legacy_session_; +Status SessionMgr::WorkerSessionForSessionLocked( + const string& session_handle, std::shared_ptr<WorkerSession>* out_session) { + if (session_handle.empty()) { + *out_session = legacy_session_; } else { - return it->second; + auto it = sessions_.find(session_handle); + if (it == sessions_.end()) { + return errors::Aborted("Session handle is not found: ", session_handle, + ". Possibly this worker just restarted."); + } else { + *out_session = it->second; + } } + return Status::OK(); } -std::shared_ptr<WorkerSession> SessionMgr::WorkerSessionForSession( - const string& session) { +Status SessionMgr::WorkerSessionForSession( + const string& session_handle, std::shared_ptr<WorkerSession>* out_session) { mutex_lock l(mu_); - return WorkerSessionForSessionUnlocked(session); + return WorkerSessionForSessionLocked(session_handle, out_session); } std::shared_ptr<WorkerSession> SessionMgr::LegacySession() { diff --git a/tensorflow/core/distributed_runtime/session_mgr.h b/tensorflow/core/distributed_runtime/session_mgr.h index 4c9702d522..0a10fe240f 100644 --- a/tensorflow/core/distributed_runtime/session_mgr.h +++ b/tensorflow/core/distributed_runtime/session_mgr.h @@ -50,7 +50,8 @@ class SessionMgr { bool isolate_session_state); // Locates the worker session for a given session handle - std::shared_ptr<WorkerSession> WorkerSessionForSession(const string& session); + Status WorkerSessionForSession(const string& session_handle, + std::shared_ptr<WorkerSession>* out_session); std::shared_ptr<WorkerSession> LegacySession(); Status DeleteSession(const string& session); @@ -86,8 +87,9 @@ class SessionMgr { const WorkerCacheFactory worker_cache_factory_; - std::shared_ptr<WorkerSession> WorkerSessionForSessionUnlocked( - const string& session) EXCLUSIVE_LOCKS_REQUIRED(mu_); + Status WorkerSessionForSessionLocked( + const string& session_handle, std::shared_ptr<WorkerSession>* out_session) + EXCLUSIVE_LOCKS_REQUIRED(mu_); mutex mu_; // A map from session identifier to internal session structure. diff --git a/tensorflow/core/distributed_runtime/session_mgr_test.cc b/tensorflow/core/distributed_runtime/session_mgr_test.cc index 4d028f7f4a..858e636e08 100644 --- a/tensorflow/core/distributed_runtime/session_mgr_test.cc +++ b/tensorflow/core/distributed_runtime/session_mgr_test.cc @@ -46,8 +46,8 @@ class SessionMgrTest : public ::testing::Test { : device_(FakeDevice::MakeCPU( "/job:mnist/replica:0/task:0/device:fakecpu:0")), mgr_(&env_, "/job:mnist/replica:0/task:0", - std::unique_ptr<WorkerCacheInterface>(), factory_), - legacy_session_(mgr_.WorkerSessionForSession("novel_session_id")) { + std::unique_ptr<WorkerCacheInterface>(), factory_) { + TF_CHECK_OK(mgr_.WorkerSessionForSession("", &legacy_session_)); env_.local_devices = {device_.get()}; } @@ -69,7 +69,8 @@ TEST_F(SessionMgrTest, CreateSessionSimple) { string session_handle = "test_session_handle"; TF_EXPECT_OK(mgr_.CreateSession(session_handle, server_def, true)); - auto session = mgr_.WorkerSessionForSession(session_handle); + std::shared_ptr<WorkerSession> session; + TF_EXPECT_OK(mgr_.WorkerSessionForSession(session_handle, &session)); EXPECT_NE(nullptr, session) << "Session for " << session_handle << "was null"; EXPECT_NE(mgr_.LegacySession(), session); TF_EXPECT_OK(mgr_.DeleteSession(session_handle)); @@ -81,22 +82,26 @@ TEST_F(SessionMgrTest, CreateSessionIsolateSessionState) { server_def.set_task_index(3); TF_EXPECT_OK(mgr_.CreateSession("handle_1", server_def, false)); - auto session_1 = mgr_.WorkerSessionForSession("handle_1"); + std::shared_ptr<WorkerSession> session_1; + TF_EXPECT_OK(mgr_.WorkerSessionForSession("handle_1", &session_1)); std::vector<Device*> devices_1 = session_1->device_mgr->ListDevices(); EXPECT_EQ(1, devices_1.size()); TF_EXPECT_OK(mgr_.CreateSession("handle_2", server_def, false)); - auto session_2 = mgr_.WorkerSessionForSession("handle_2"); + std::shared_ptr<WorkerSession> session_2; + TF_EXPECT_OK(mgr_.WorkerSessionForSession("handle_2", &session_2)); std::vector<Device*> devices_2 = session_2->device_mgr->ListDevices(); EXPECT_EQ(1, devices_2.size()); TF_EXPECT_OK(mgr_.CreateSession("handle_3", server_def, true)); - auto session_3 = mgr_.WorkerSessionForSession("handle_3"); + std::shared_ptr<WorkerSession> session_3; + TF_EXPECT_OK(mgr_.WorkerSessionForSession("handle_3", &session_3)); std::vector<Device*> devices_3 = session_3->device_mgr->ListDevices(); EXPECT_EQ(1, devices_3.size()); TF_EXPECT_OK(mgr_.CreateSession("handle_4", server_def, true)); - auto session_4 = mgr_.WorkerSessionForSession("handle_4"); + std::shared_ptr<WorkerSession> session_4; + TF_EXPECT_OK(mgr_.WorkerSessionForSession("handle_4", &session_4)); std::vector<Device*> devices_4 = session_4->device_mgr->ListDevices(); EXPECT_EQ(1, devices_4.size()); @@ -109,12 +114,23 @@ TEST_F(SessionMgrTest, CreateSessionIsolateSessionState) { TEST_F(SessionMgrTest, LegacySession) { ServerDef server_def; string session_handle = ""; - auto session = mgr_.WorkerSessionForSession(session_handle); + std::shared_ptr<WorkerSession> session; + TF_EXPECT_OK(mgr_.WorkerSessionForSession(session_handle, &session)); EXPECT_EQ(mgr_.LegacySession(), session); TF_EXPECT_OK(mgr_.DeleteSession(session_handle)); } +TEST_F(SessionMgrTest, UnknownSessionHandle) { + ServerDef server_def; + string session_handle = "unknown_session_handle"; + std::shared_ptr<WorkerSession> session; + Status s = mgr_.WorkerSessionForSession(session_handle, &session); + EXPECT_TRUE(errors::IsAborted(s)); + EXPECT_TRUE( + str_util::StrContains(s.error_message(), "Session handle is not found")); +} + TEST_F(SessionMgrTest, WorkerNameFromServerDef) { ServerDef server_def; server_def.set_job_name("worker"); @@ -124,7 +140,7 @@ TEST_F(SessionMgrTest, WorkerNameFromServerDef) { } TEST_F(SessionMgrTest, DeleteLegacySession) { - TF_EXPECT_OK(mgr_.DeleteSession("legacy_session")); + TF_EXPECT_OK(mgr_.DeleteSession("")); } } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/worker.cc b/tensorflow/core/distributed_runtime/worker.cc index 598652fb98..6b2536c3c0 100644 --- a/tensorflow/core/distributed_runtime/worker.cc +++ b/tensorflow/core/distributed_runtime/worker.cc @@ -59,21 +59,37 @@ void Worker::DeleteWorkerSessionAsync(const DeleteWorkerSessionRequest* request, void Worker::RegisterGraphAsync(const RegisterGraphRequest* request, RegisterGraphResponse* response, StatusCallback done) { - auto session = - env_->session_mgr->WorkerSessionForSession(request->session_handle()); - Status s = session->graph_mgr->Register( - request->session_handle(), request->graph_def(), request->graph_options(), - request->debug_options(), session->cluster_flr.get(), - response->mutable_graph_handle()); + std::shared_ptr<WorkerSession> session; + Status s; + if (request->create_worker_session_called()) { + s = env_->session_mgr->WorkerSessionForSession(request->session_handle(), + &session); + } else { + session = env_->session_mgr->LegacySession(); + } + if (s.ok()) { + s = session->graph_mgr->Register( + request->session_handle(), request->graph_def(), + request->graph_options(), request->debug_options(), + session->cluster_flr.get(), response->mutable_graph_handle()); + } done(s); } void Worker::DeregisterGraphAsync(const DeregisterGraphRequest* request, DeregisterGraphResponse* response, StatusCallback done) { - auto session = - env_->session_mgr->WorkerSessionForSession(request->session_handle()); - Status s = session->graph_mgr->Deregister(request->graph_handle()); + std::shared_ptr<WorkerSession> session; + Status s; + if (request->create_worker_session_called()) { + s = env_->session_mgr->WorkerSessionForSession(request->session_handle(), + &session); + } else { + session = env_->session_mgr->LegacySession(); + } + if (s.ok()) { + s = session->graph_mgr->Deregister(request->graph_handle()); + } done(s); } @@ -135,11 +151,21 @@ void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request, StatusCallback done) { const int64 step_id = request->step_id(); TRACEPRINTF("RunGraph: %lld", step_id); - auto session = - env_->session_mgr->WorkerSessionForSession(request->session_handle()); + std::shared_ptr<WorkerSession> session; + Status s; + if (request->create_worker_session_called()) { + s = env_->session_mgr->WorkerSessionForSession(request->session_handle(), + &session); + } else { + session = env_->session_mgr->LegacySession(); + } + if (!s.ok()) { + done(s); + return; + } GraphMgr::NamedTensors in; GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors; - Status s = PrepareRunGraph(request, &in, out); + s = PrepareRunGraph(request, &in, out); if (!s.ok()) { delete out; done(s); @@ -209,12 +235,23 @@ void Worker::DoPartialRunGraph(CallOptions* opts, const int64 step_id = request->step_id(); const string& graph_handle = request->graph_handle(); TRACEPRINTF("PartialRunGraph: %lld", step_id); - auto session = - env_->session_mgr->WorkerSessionForSession(request->session_handle()); + std::shared_ptr<WorkerSession> session; + + Status s; + if (request->create_worker_session_called()) { + s = env_->session_mgr->WorkerSessionForSession(request->session_handle(), + &session); + } else { + session = env_->session_mgr->LegacySession(); + } + if (!s.ok()) { + done(s); + return; + } GraphMgr::NamedTensors in; GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors; - Status s = PrepareRunGraph(request, &in, out); + s = PrepareRunGraph(request, &in, out); auto finish = [done, out, opts](const Status& s) { opts->ClearCancelCallback(); delete out; diff --git a/tensorflow/core/distributed_runtime/worker_session.cc b/tensorflow/core/distributed_runtime/worker_session.cc index cb7059b36e..18886babd5 100644 --- a/tensorflow/core/distributed_runtime/worker_session.cc +++ b/tensorflow/core/distributed_runtime/worker_session.cc @@ -97,6 +97,7 @@ WorkerSession::WorkerSession(const string& session_name, worker_cache(new WorkerFreeListCache(std::move(worker_cache))), device_mgr(std::move(device_mgr)), graph_mgr(std::move(graph_mgr)), - cluster_flr(new ClusterFunctionLibraryRuntime(this)) {} + cluster_flr( + new ClusterFunctionLibraryRuntime(this, !session_name.empty())) {} } // namespace tensorflow diff --git a/tensorflow/core/protobuf/worker.proto b/tensorflow/core/protobuf/worker.proto index 3e7289bd91..1819a35248 100644 --- a/tensorflow/core/protobuf/worker.proto +++ b/tensorflow/core/protobuf/worker.proto @@ -103,6 +103,9 @@ message RegisterGraphRequest { // Subgraphs are scoped within one session. string session_handle = 1; + // Set to true if `CreateWorkerSession` was called for `session_handle`. + bool create_worker_session_called = 6; + // "graph_def" has the subgraph of nodes for this worker, with each node // having its device_name filled in. GraphDef graph_def = 2; @@ -144,6 +147,9 @@ message DeregisterGraphRequest { // empty, a single global namespace is used. string session_handle = 2; + // Set to true if `CreateWorkerSession` was called for `session_handle`. + bool create_worker_session_called = 3; + // REQUIRED: graph_handle must be returned by a RegisterGraph call // to the same WorkerService. string graph_handle = 1; @@ -200,6 +206,9 @@ message RunGraphRequest { // search for the graph_handle. string session_handle = 8; + // Set to true if `CreateWorkerSession` was called for `session_handle`. + bool create_worker_session_called = 10; + // REQUIRED: graph_handle must be returned by a RegisterGraph call // to the same WorkerService. string graph_handle = 1; @@ -234,7 +243,7 @@ message RunGraphRequest { // truncate long metadata messages. bool store_errors_in_response_body = 9; - // Next: 10 + // Next: 11 } message RunGraphResponse { |