diff options
author | Akshay Modi <nareshmodi@google.com> | 2018-08-08 17:11:32 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-08 17:20:02 -0700 |
commit | 3325275eff98ffddb52a16db932481983a9de9a8 (patch) | |
tree | af5218ea5ea92288d66edade8150aab81f5c095f /tensorflow/core/distributed_runtime | |
parent | 3e02edc1f33fb3bfa43b5828d8ecea0dbc7738ea (diff) |
Support keep alive so we can reclaim memory in the remote case.
PiperOrigin-RevId: 207971672
Diffstat (limited to 'tensorflow/core/distributed_runtime')
3 files changed, 118 insertions, 8 deletions
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc index 916c8720f0..b8af63724a 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc @@ -126,7 +126,9 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request, do { context_id = random::New64(); } while (contexts_.find(context_id) != contexts_.end()); - contexts_.emplace(context_id, new ServerContext(std::move(ctx))); + contexts_.emplace( + context_id, + new ServerContext(std::move(ctx), request->keep_alive_secs(), env_)); } response->set_context_id(context_id); @@ -231,9 +233,11 @@ Status EagerServiceImpl::WaitQueueDone(const WaitQueueDoneRequest* request, Status EagerServiceImpl::KeepAlive(const KeepAliveRequest* request, KeepAliveResponse* response) { - // TODO(nareshmodi): Automated context_id cleaning is not implemented - return errors::Unimplemented( - "EagerServiceImpl::KeepAlive is not implemented."); + ServerContext* context = nullptr; + TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context)); + core::ScopedUnref context_unref(context); + + return Status::OK(); } Status EagerServiceImpl::CloseContext(const CloseContextRequest* request, @@ -304,12 +308,15 @@ tensorflow::Status EagerServiceImpl::GetServerContext( *server_context = nullptr; return errors::InvalidArgument(strings::Printf( "Unable to find a context_id matching the specified one " - "(%lld). Perhaps the worker was restarted?", + "(%lld). Perhaps the worker was restarted, or the context was GC'd?", context_id)); } *server_context = iter->second; (*server_context)->Ref(); + + (*server_context)->RecordAccess(); + return Status::OK(); } diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.h b/tensorflow/core/distributed_runtime/eager/eager_service_impl.h index 718b4e2457..5723106aa6 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.h +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.h @@ -38,8 +38,41 @@ namespace eager { // over this (e.g. gRPC). class EagerServiceImpl { public: - explicit EagerServiceImpl(const WorkerEnv* env) : env_(env) {} + explicit EagerServiceImpl(const WorkerEnv* env) : env_(env) { + gc_thread_.reset( + env_->env->StartThread({}, "EagerServiceContextGC", [this]() { + while (true) { + { + mutex_lock l(gc_thread_shutdown_mu_); + gc_thread_cv_.wait_for(l, std::chrono::seconds(1)); + + if (shutting_down_) { + return; + } + } + { + mutex_lock l(contexts_mu_); + for (auto it = contexts_.begin(); it != contexts_.end();) { + if (it->second->IsStale()) { + it->second->Unref(); + it = contexts_.erase(it); + } else { + it++; + } + } + } + } + })); + } virtual ~EagerServiceImpl() { + { + mutex_lock l(gc_thread_shutdown_mu_); + shutting_down_ = true; + gc_thread_cv_.notify_all(); + } + gc_thread_.reset(); + + mutex_lock l(contexts_mu_); for (auto& entry : contexts_) { entry.second->Unref(); } @@ -71,8 +104,13 @@ class EagerServiceImpl { // and the EagerContext). 
class ServerContext : public core::RefCounted { public: - explicit ServerContext(std::unique_ptr<tensorflow::EagerContext> ctx) - : ctx_(std::move(ctx)) {} + explicit ServerContext(std::unique_ptr<tensorflow::EagerContext> ctx, + int64 destroy_after_secs, const WorkerEnv* env) + : ctx_(std::move(ctx)), env_(env) { + destroy_after_micros_ = + destroy_after_secs * tensorflow::EnvTime::kSecondsToMicros; + RecordAccess(); + } ~ServerContext() { for (const auto& entry : tensors_) { entry.second->Unref(); @@ -122,6 +160,18 @@ class EagerServiceImpl { return Status::OK(); } + void RecordAccess() { + mutex_lock l(last_accessed_mu_); + last_accessed_micros_ = env_->env->NowMicros(); + } + + bool IsStale() { + mutex_lock l(last_accessed_mu_); + return (destroy_after_micros_ <= 0 || + (env_->env->NowMicros() - last_accessed_micros_) > + destroy_after_micros_); + } + private: using RemoteTensorHandleMap = gtl::FlatMap<RemoteTensorHandleInternal, tensorflow::TensorHandle*, @@ -131,8 +181,15 @@ class EagerServiceImpl { // The context for this execution. std::unique_ptr<tensorflow::EagerContext> ctx_; + // The state related to the context for this execution. mutex tensors_mu_; RemoteTensorHandleMap tensors_ GUARDED_BY(tensors_mu_); + + const WorkerEnv* const env_; // Not owned. + + mutex last_accessed_mu_; + int64 last_accessed_micros_ GUARDED_BY(last_accessed_mu_); + int64 destroy_after_micros_; }; // The returned ServerContext will need to be Unrefed. 
tensorflow::Status GetServerContext(uint64, ServerContext**); @@ -145,6 +202,11 @@ class EagerServiceImpl { mutex contexts_mu_; std::unordered_map<uint64, ServerContext*> contexts_ GUARDED_BY(contexts_mu_); + std::unique_ptr<Thread> gc_thread_; + mutex gc_thread_shutdown_mu_; + condition_variable gc_thread_cv_; + bool shutting_down_ GUARDED_BY(gc_thread_shutdown_mu_) = false; + TF_DISALLOW_COPY_AND_ASSIGN(EagerServiceImpl); }; diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc index d1f2a6da8f..5c9b33b345 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc @@ -365,6 +365,47 @@ TEST_F(EagerServiceImplTest, SendTensorTest) { &close_context_response)); } +TEST_F(EagerServiceImplTest, KeepAliveTest) { + TestEagerServiceImpl eager_service_impl(&worker_env_); + + CreateContextRequest request; + request.mutable_server_def()->set_job_name("localhost"); + request.mutable_server_def()->set_task_index(0); + request.set_rendezvous_id(random::New64()); + request.set_keep_alive_secs(3); + CreateContextResponse response; + + TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response)); + + worker_env_.env->SleepForMicroseconds(5 * + tensorflow::EnvTime::kSecondsToMicros); + + KeepAliveRequest keep_alive_request; + KeepAliveResponse keep_alive_response; + + keep_alive_request.set_context_id(response.context_id()); + + Status status = + eager_service_impl.KeepAlive(&keep_alive_request, &keep_alive_response); + + EXPECT_EQ(status.code(), error::INVALID_ARGUMENT); + EXPECT_PRED_FORMAT2(::testing::IsSubstring, "Unable to find a context_id", + status.error_message()); + + // Create a new context. + request.set_rendezvous_id(random::New64()); + TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response)); + + // The context should not be GC'd. 
+ worker_env_.env->SleepForMicroseconds(1 * + tensorflow::EnvTime::kSecondsToMicros); + + keep_alive_request.set_context_id(response.context_id()); + + TF_ASSERT_OK( + eager_service_impl.KeepAlive(&keep_alive_request, &keep_alive_response)); +} + } // namespace } // namespace eager } // namespace tensorflow |