diff options
Diffstat (limited to 'tensorflow/core/distributed_runtime/eager/eager_service_impl.h')
-rw-r--r-- | tensorflow/core/distributed_runtime/eager/eager_service_impl.h | 68 |
1 files changed, 65 insertions, 3 deletions
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.h b/tensorflow/core/distributed_runtime/eager/eager_service_impl.h index 718b4e2457..2784c5d26e 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.h +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.h @@ -38,8 +38,41 @@ namespace eager { // over this (e.g. gRPC). class EagerServiceImpl { public: - explicit EagerServiceImpl(const WorkerEnv* env) : env_(env) {} + explicit EagerServiceImpl(const WorkerEnv* env) : env_(env) { + gc_thread_.reset( + env_->env->StartThread({}, "EagerServiceContextGC", [this]() { + while (true) { + { + mutex_lock l(gc_thread_shutdown_mu_); + gc_thread_cv_.wait_for(l, std::chrono::seconds(1)); + + if (shutting_down_) { + return; + } + } + { + mutex_lock l(contexts_mu_); + for (auto it = contexts_.begin(); it != contexts_.end();) { + if (it->second->IsStale()) { + it->second->Unref(); + it = contexts_.erase(it); + } else { + it++; + } + } + } + } + })); + } virtual ~EagerServiceImpl() { + { + mutex_lock l(gc_thread_shutdown_mu_); + shutting_down_ = true; + gc_thread_cv_.notify_all(); + } + gc_thread_.reset(); + + mutex_lock l(contexts_mu_); for (auto& entry : contexts_) { entry.second->Unref(); } @@ -71,8 +104,13 @@ class EagerServiceImpl { // and the EagerContext). class ServerContext : public core::RefCounted { public: - explicit ServerContext(std::unique_ptr<tensorflow::EagerContext> ctx) - : ctx_(std::move(ctx)) {} + explicit ServerContext(std::unique_ptr<tensorflow::EagerContext> ctx, + int64 destroy_after_secs, const WorkerEnv* env) + : ctx_(std::move(ctx)), env_(env) { + destroy_after_micros_ = + destroy_after_secs * tensorflow::EnvTime::kSecondsToMicros; + RecordAccess(); + } ~ServerContext() { for (const auto& entry : tensors_) { entry.second->Unref(); @@ -122,6 +160,18 @@ class EagerServiceImpl { return Status::OK(); } + void RecordAccess() { + mutex_lock l(last_accessed_mu_); + last_accessed_micros_ = env_->env->NowMicros(); + } + + bool IsStale() { + mutex_lock l(last_accessed_mu_); + return (destroy_after_micros_ > 0 && + (env_->env->NowMicros() - last_accessed_micros_) > + destroy_after_micros_); + } + private: using RemoteTensorHandleMap = gtl::FlatMap<RemoteTensorHandleInternal, tensorflow::TensorHandle*, @@ -131,8 +181,15 @@ class EagerServiceImpl { // The context for this execution. std::unique_ptr<tensorflow::EagerContext> ctx_; + // The state related to the context for this execution. mutex tensors_mu_; RemoteTensorHandleMap tensors_ GUARDED_BY(tensors_mu_); + + const WorkerEnv* const env_; // Not owned. + + mutex last_accessed_mu_; + int64 last_accessed_micros_ GUARDED_BY(last_accessed_mu_); + int64 destroy_after_micros_; }; // The returned ServerContext will need to be Unrefed. tensorflow::Status GetServerContext(uint64, ServerContext**); @@ -145,6 +202,11 @@ class EagerServiceImpl { mutex contexts_mu_; std::unordered_map<uint64, ServerContext*> contexts_ GUARDED_BY(contexts_mu_); + std::unique_ptr<Thread> gc_thread_; + mutex gc_thread_shutdown_mu_; + condition_variable gc_thread_cv_; + bool shutting_down_ GUARDED_BY(gc_thread_shutdown_mu_) = false; + TF_DISALLOW_COPY_AND_ASSIGN(EagerServiceImpl); }; |