diff options
author | Akshay Modi <nareshmodi@google.com> | 2018-08-08 17:11:32 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-08 17:20:02 -0700 |
commit | 3325275eff98ffddb52a16db932481983a9de9a8 (patch) | |
tree | af5218ea5ea92288d66edade8150aab81f5c095f | |
parent | 3e02edc1f33fb3bfa43b5828d8ecea0dbc7738ea (diff) |
Support keep alive so we can reclaim memory in the remote case.
PiperOrigin-RevId: 207971672
-rw-r--r-- | tensorflow/c/eager/c_api.cc | 21 | ||||
-rw-r--r-- | tensorflow/c/eager/c_api.h | 1 | ||||
-rw-r--r-- | tensorflow/c/eager/c_api_test.cc | 8 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/eager/context.cc | 59 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/eager/context.h | 14 | ||||
-rw-r--r-- | tensorflow/core/distributed_runtime/eager/eager_service_impl.cc | 17 | ||||
-rw-r--r-- | tensorflow/core/distributed_runtime/eager/eager_service_impl.h | 68 | ||||
-rw-r--r-- | tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc | 41 | ||||
-rw-r--r-- | tensorflow/python/eager/context.py | 11 |
9 files changed, 214 insertions, 26 deletions
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index d7073d8e05..dfb1c9a376 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -110,7 +110,7 @@ tensorflow::Status GetAllRemoteDevices( tensorflow::Status CreateRemoteContexts( const std::vector<string>& remote_workers, int64 rendezvous_id, - const tensorflow::ServerDef& server_def, + int keep_alive_secs, const tensorflow::ServerDef& server_def, tensorflow::eager::EagerClientCache* remote_eager_workers, bool async, tensorflow::gtl::FlatMap<string, tensorflow::uint64>* remote_contexts) { for (int i = 0; i < remote_workers.size(); i++) { @@ -129,6 +129,7 @@ tensorflow::Status CreateRemoteContexts( request.mutable_server_def()->set_job_name(parsed_name.job); request.mutable_server_def()->set_task_index(parsed_name.task); request.set_async(async); + request.set_keep_alive_secs(keep_alive_secs); auto* eager_client = remote_eager_workers->GetClient(remote_worker); if (eager_client == nullptr) { return tensorflow::errors::Internal( @@ -151,7 +152,8 @@ tensorflow::Status CreateRemoteContexts( } tensorflow::Status UpdateTFE_ContextWithServerDef( - const tensorflow::ServerDef& server_def, TFE_Context* ctx) { + int keep_alive_secs, const tensorflow::ServerDef& server_def, + TFE_Context* ctx) { // We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the // server object (which currently CHECK-fails) and we miss the error, instead, // we log the error, and then return to allow the user to see the error @@ -202,8 +204,8 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( // Initialize remote eager workers. tensorflow::gtl::FlatMap<string, tensorflow::uint64> remote_contexts; LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts( - remote_workers, rendezvous_id, server_def, remote_eager_workers.get(), - ctx->context.Async(), &remote_contexts)); + remote_workers, rendezvous_id, keep_alive_secs, server_def, + remote_eager_workers.get(), ctx->context.Async(), &remote_contexts)); tensorflow::RemoteRendezvous* r = grpc_server->worker_env()->rendezvous_mgr->Find(rendezvous_id); @@ -222,9 +224,10 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( auto* device_mgr = grpc_server->worker_env()->device_mgr; - ctx->context.InitializeRemote( - std::move(server), std::move(remote_eager_workers), - std::move(remote_device_mgr), remote_contexts, r, device_mgr); + ctx->context.InitializeRemote(std::move(server), + std::move(remote_eager_workers), + std::move(remote_device_mgr), remote_contexts, + r, device_mgr, keep_alive_secs); return tensorflow::Status::OK(); #undef LOG_AND_RETURN_IF_ERROR @@ -288,6 +291,7 @@ void TFE_ContextClearCaches(TFE_Context* ctx) { ctx->context.ClearCaches(); } // Set server_def on the context, possibly updating it. TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx, + int keep_alive_secs, const void* proto, size_t proto_len, TF_Status* status) { @@ -297,7 +301,8 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx, "Invalid tensorflow.ServerDef protocol buffer"); return; } - status->status = UpdateTFE_ContextWithServerDef(server_def, ctx); + status->status = + UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def, ctx); } void TFE_ContextSetThreadLocalDevicePlacementPolicy( diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index 092af45731..a0ebc6fa0a 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -124,6 +124,7 @@ TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context*, // If the following is set, all servers identified by the // ServerDef must be up when the context is created. TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx, + int keep_alive_secs, const void* proto, size_t proto_len, TF_Status* status); diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 00a0a71fca..71d5f3613c 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -151,7 +151,7 @@ void TestRemoteExecute(bool async) { EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); - TFE_ContextSetServerDef(ctx, serialized.data(), serialized.size(), status); + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(); @@ -239,7 +239,7 @@ void TestRemoteExecuteSilentCopies(bool async) { EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); - TFE_ContextSetServerDef(ctx, serialized.data(), serialized.size(), status); + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(); @@ -371,7 +371,7 @@ void TestRemoteExecuteChangeServerDef(bool async) { EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); - TFE_ContextSetServerDef(ctx, serialized.data(), serialized.size(), status); + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); const char remote_device_name[] = @@ -397,7 +397,7 @@ void TestRemoteExecuteChangeServerDef(bool async) { ASSERT_TRUE(s.ok()) << s.error_message(); ASSERT_TRUE(worker_server->Start().ok()); - TFE_ContextSetServerDef(ctx, serialized.data(), serialized.size(), status); + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); // Create a new tensor_handle. diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc index e5fe87fc37..5bdd547c7f 100644 --- a/tensorflow/core/common_runtime/eager/context.cc +++ b/tensorflow/core/common_runtime/eager/context.cc @@ -163,6 +163,13 @@ EagerContext::~EagerContext() { server_.release(); } + { + mutex_lock l(keep_alive_thread_shutdown_mu_); + shutting_down_ = true; + keep_alive_thread_cv_.notify_all(); + } + keep_alive_thread_.reset(); + CloseRemoteContexts(); #endif @@ -334,7 +341,9 @@ void EagerContext::InitializeRemote( std::unique_ptr<eager::EagerClientCache> remote_eager_workers, std::unique_ptr<DeviceMgr> remote_device_manager, const gtl::FlatMap<string, uint64>& remote_contexts, Rendezvous* r, - DeviceMgr* local_device_mgr) { + DeviceMgr* local_device_mgr, int keep_alive_secs) { + mutex_lock l(remote_state_mu_); + if (!remote_contexts_.empty()) { CloseRemoteContexts(); } @@ -376,6 +385,54 @@ void EagerContext::InitializeRemote( InitDeviceMapAndAsync(); ClearCaches(); + + keep_alive_secs_ = keep_alive_secs; + + sleep_for_secs_ = std::max(1, keep_alive_secs_ / 2); + + // Only schedule a single closure. + if (keep_alive_thread_ == nullptr) { + keep_alive_thread_.reset( + env_->StartThread({}, "EagerKeepAliveThread", [this]() { + while (true) { + { + { + mutex_lock l(keep_alive_thread_shutdown_mu_); + keep_alive_thread_cv_.wait_for( + l, std::chrono::seconds(sleep_for_secs_)); + + if (shutting_down_) { + return; + } + } + { + mutex_lock l(remote_state_mu_); + if (keep_alive_secs_ > 0) { + { + for (const auto& worker_and_context_id : remote_contexts_) { + auto* client = remote_eager_workers_->GetClient( + worker_and_context_id.first); + + eager::KeepAliveRequest* request = + new eager::KeepAliveRequest; + eager::KeepAliveResponse* response = + new eager::KeepAliveResponse; + + request->set_context_id(worker_and_context_id.second); + client->KeepAliveAsync( + request, response, + [request, response](const Status& s) { + delete request; + delete response; + }); + } + } + } + } + } + } + })); + } } #endif diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h index 3eea56b5e3..ebaf500bb3 100644 --- a/tensorflow/core/common_runtime/eager/context.h +++ b/tensorflow/core/common_runtime/eager/context.h @@ -180,7 +180,7 @@ class EagerContext { std::unique_ptr<eager::EagerClientCache> remote_eager_workers, std::unique_ptr<DeviceMgr> remote_device_manager, const gtl::FlatMap<string, uint64>& remote_contexts, Rendezvous* r, - DeviceMgr* local_device_mgr); + DeviceMgr* local_device_mgr, int keep_alive_secs); bool HasActiveRemoteContext(uint64 context_id) { return active_remote_contexts_.find(context_id) != @@ -190,7 +190,7 @@ class EagerContext { // If true, then tensors should be shipped across processes via the // EagerService.SendTensor RPC. If false, _Send/_Recv ops should be used - // instead (which in-turn use WorkerService.RecvTensor RPCs. + // instead (which in-turn use WorkerService.RecvTensor RPCs). bool UseSendTensorRPC() { return use_send_tensor_rpc_; } private: @@ -263,10 +263,20 @@ class EagerContext { std::unique_ptr<ServerInterface> server_; std::unique_ptr<eager::EagerClientCache> remote_eager_workers_; + mutex remote_state_mu_; + gtl::FlatMap<string, uint64> remote_contexts_; gtl::FlatSet<uint64> active_remote_contexts_; gtl::FlatMap<Device*, std::pair<eager::EagerClient*, uint64>> device_to_client_cache_; + + int keep_alive_secs_ GUARDED_BY(remote_state_mu_); + std::atomic<int> sleep_for_secs_; + + std::unique_ptr<Thread> keep_alive_thread_; + mutex keep_alive_thread_shutdown_mu_; + condition_variable keep_alive_thread_cv_; + bool shutting_down_ GUARDED_BY(keep_alive_thread_shutdown_mu_) = false; #endif bool use_send_tensor_rpc_; diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc index 916c8720f0..b8af63724a 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc @@ -126,7 +126,9 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request, do { context_id = random::New64(); } while (contexts_.find(context_id) != contexts_.end()); - contexts_.emplace(context_id, new ServerContext(std::move(ctx))); + contexts_.emplace( + context_id, + new ServerContext(std::move(ctx), request->keep_alive_secs(), env_)); } response->set_context_id(context_id); @@ -231,9 +233,11 @@ Status EagerServiceImpl::WaitQueueDone(const WaitQueueDoneRequest* request, Status EagerServiceImpl::KeepAlive(const KeepAliveRequest* request, KeepAliveResponse* response) { - // TODO(nareshmodi): Automated context_id cleaning is not implemented - return errors::Unimplemented( - "EagerServiceImpl::KeepAlive is not implemented."); + ServerContext* context = nullptr; + TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context)); + core::ScopedUnref context_unref(context); + + return Status::OK(); } Status EagerServiceImpl::CloseContext(const CloseContextRequest* request, @@ -304,12 +308,15 @@ tensorflow::Status EagerServiceImpl::GetServerContext( *server_context = nullptr; return errors::InvalidArgument(strings::Printf( "Unable to find a context_id matching the specified one " - "(%lld). Perhaps the worker was restarted?", + "(%lld). Perhaps the worker was restarted, or the context was GC'd?", context_id)); } *server_context = iter->second; (*server_context)->Ref(); + + (*server_context)->RecordAccess(); + return Status::OK(); } diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.h b/tensorflow/core/distributed_runtime/eager/eager_service_impl.h index 718b4e2457..5723106aa6 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.h +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.h @@ -38,8 +38,41 @@ namespace eager { // over this (e.g. gRPC). class EagerServiceImpl { public: - explicit EagerServiceImpl(const WorkerEnv* env) : env_(env) {} + explicit EagerServiceImpl(const WorkerEnv* env) : env_(env) { + gc_thread_.reset( + env_->env->StartThread({}, "EagerServiceContextGC", [this]() { + while (true) { + { + mutex_lock l(gc_thread_shutdown_mu_); + gc_thread_cv_.wait_for(l, std::chrono::seconds(1)); + + if (shutting_down_) { + return; + } + } + { + mutex_lock l(contexts_mu_); + for (auto it = contexts_.begin(); it != contexts_.end();) { + if (it->second->IsStale()) { + it->second->Unref(); + it = contexts_.erase(it); + } else { + it++; + } + } + } + } + })); + } virtual ~EagerServiceImpl() { + { + mutex_lock l(gc_thread_shutdown_mu_); + shutting_down_ = true; + gc_thread_cv_.notify_all(); + } + gc_thread_.reset(); + + mutex_lock l(contexts_mu_); for (auto& entry : contexts_) { entry.second->Unref(); } @@ -71,8 +104,13 @@ class EagerServiceImpl { // and the EagerContext). class ServerContext : public core::RefCounted { public: - explicit ServerContext(std::unique_ptr<tensorflow::EagerContext> ctx) - : ctx_(std::move(ctx)) {} + explicit ServerContext(std::unique_ptr<tensorflow::EagerContext> ctx, + int64 destroy_after_secs, const WorkerEnv* env) + : ctx_(std::move(ctx)), env_(env) { + destroy_after_micros_ = + destroy_after_secs * tensorflow::EnvTime::kSecondsToMicros; + RecordAccess(); + } ~ServerContext() { for (const auto& entry : tensors_) { entry.second->Unref(); @@ -122,6 +160,18 @@ class EagerServiceImpl { return Status::OK(); } + void RecordAccess() { + mutex_lock l(last_accessed_mu_); + last_accessed_micros_ = env_->env->NowMicros(); + } + + bool IsStale() { + mutex_lock l(last_accessed_mu_); + return (destroy_after_micros_ <= 0 || + (env_->env->NowMicros() - last_accessed_micros_) > + destroy_after_micros_); + } + private: using RemoteTensorHandleMap = gtl::FlatMap<RemoteTensorHandleInternal, tensorflow::TensorHandle*, @@ -131,8 +181,15 @@ class EagerServiceImpl { // The context for this execution. std::unique_ptr<tensorflow::EagerContext> ctx_; + // The state related to the context for this execution. mutex tensors_mu_; RemoteTensorHandleMap tensors_ GUARDED_BY(tensors_mu_); + + const WorkerEnv* const env_; // Not owned. + + mutex last_accessed_mu_; + int64 last_accessed_micros_ GUARDED_BY(last_accessed_mu_); + int64 destroy_after_micros_; }; // The returned ServerContext will need to be Unrefed. tensorflow::Status GetServerContext(uint64, ServerContext**); @@ -145,6 +202,11 @@ class EagerServiceImpl { mutex contexts_mu_; std::unordered_map<uint64, ServerContext*> contexts_ GUARDED_BY(contexts_mu_); + std::unique_ptr<Thread> gc_thread_; + mutex gc_thread_shutdown_mu_; + condition_variable gc_thread_cv_; + bool shutting_down_ GUARDED_BY(gc_thread_shutdown_mu_) = false; + TF_DISALLOW_COPY_AND_ASSIGN(EagerServiceImpl); }; diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc index d1f2a6da8f..5c9b33b345 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc @@ -365,6 +365,47 @@ TEST_F(EagerServiceImplTest, SendTensorTest) { &close_context_response)); } +TEST_F(EagerServiceImplTest, KeepAliveTest) { + TestEagerServiceImpl eager_service_impl(&worker_env_); + + CreateContextRequest request; + request.mutable_server_def()->set_job_name("localhost"); + request.mutable_server_def()->set_task_index(0); + request.set_rendezvous_id(random::New64()); + request.set_keep_alive_secs(3); + CreateContextResponse response; + + TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response)); + + worker_env_.env->SleepForMicroseconds(5 * + tensorflow::EnvTime::kSecondsToMicros); + + KeepAliveRequest keep_alive_request; + KeepAliveResponse keep_alive_response; + + keep_alive_request.set_context_id(response.context_id()); + + Status status = + eager_service_impl.KeepAlive(&keep_alive_request, &keep_alive_response); + + EXPECT_EQ(status.code(), error::INVALID_ARGUMENT); + EXPECT_PRED_FORMAT2(::testing::IsSubstring, "Unable to find a context_id", + status.error_message()); + + // Create a new context. + request.set_rendezvous_id(random::New64()); + TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response)); + + // The context should not be GC'd. + worker_env_.env->SleepForMicroseconds(1 * + tensorflow::EnvTime::kSecondsToMicros); + + keep_alive_request.set_context_id(response.context_id()); + + TF_ASSERT_OK( + eager_service_impl.KeepAlive(&keep_alive_request, &keep_alive_response)); +} + } // namespace } // namespace eager } // namespace tensorflow diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index 09223c86d4..aa57ca03e6 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -265,7 +265,7 @@ class Context(object): pywrap_tensorflow.TFE_DeleteContextOptions(opts) if self._server_def is not None: server_def_str = self._server_def.SerializeToString() - pywrap_tensorflow.TFE_ContextSetServerDef(self._context_handle, + pywrap_tensorflow.TFE_ContextSetServerDef(self._context_handle, 600, server_def_str) self._initialize_devices() @@ -275,7 +275,7 @@ class Context(object): self.ones_rank_cache().flush() self.zeros_cache().flush() - def set_server_def(self, server_def): + def set_server_def(self, server_def, keep_alive_secs=600): """Allow setting a server_def on the context. When a server def is replaced, it effectively clears a bunch of caches @@ -285,6 +285,11 @@ class Context(object): Args: server_def: A tensorflow::ServerDef proto. Enables execution on remote devices. + keep_alive_secs: Num. seconds after which the remote end will hang up. + As long as the client is still alive, the server state for the context + will be kept alive. If the client is killed (or there is some failure), + the server will clean up its context keep_alive_secs after the final RPC + it receives. Raises: ValueError: if server_def is None. @@ -296,7 +301,7 @@ class Context(object): else: server_def_str = server_def.SerializeToString() pywrap_tensorflow.TFE_ContextSetServerDef(self._context_handle, - server_def_str) + keep_alive_secs, server_def_str) # Clear all the caches in case there are remote tensors in them. self._clear_caches() |