diff options
Diffstat (limited to 'tensorflow/core/common_runtime/eager/context.cc')
-rw-r--r-- | tensorflow/core/common_runtime/eager/context.cc | 90 |
1 files changed, 89 insertions, 1 deletions
diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc index 6ab2d1ebf1..5bdd547c7f 100644 --- a/tensorflow/core/common_runtime/eager/context.cc +++ b/tensorflow/core/common_runtime/eager/context.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/common_runtime/process_util.h" +#include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/lib/core/blocking_counter.h" #include "tensorflow/core/util/env_var.h" @@ -46,6 +47,7 @@ EagerContext::EagerContext(const SessionOptions& opts, local_device_manager_.get(), opts.env, TF_GRAPH_DEF_VERSION, &func_lib_def_, {}, thread_pool_.get())), log_device_placement_(opts.config.log_device_placement()), + num_active_steps_(0), async_default_(async), env_(opts.env), use_send_tensor_rpc_(false) { @@ -161,6 +163,13 @@ EagerContext::~EagerContext() { server_.release(); } + { + mutex_lock l(keep_alive_thread_shutdown_mu_); + shutting_down_ = true; + keep_alive_thread_cv_.notify_all(); + } + keep_alive_thread_.reset(); + CloseRemoteContexts(); #endif @@ -194,6 +203,35 @@ Status EagerContext::FindDeviceByName(const string& name, Device** result) { return Status::OK(); } +void EagerContext::StartStep() { + mutex_lock ml(metadata_mu_); + num_active_steps_++; + if (step_container_ == nullptr) { + step_container_.reset( + new ScopedStepContainer(0, [this](const string& name) { + for (Device* device : devices_) { + device->resource_manager()->Cleanup(name).IgnoreError(); + } + })); + } +} + +void EagerContext::EndStep() { + mutex_lock ml(metadata_mu_); + num_active_steps_--; + if (num_active_steps_ == 0) { + step_container_.reset(); + } +} + +ScopedStepContainer* EagerContext::StepContainer() { + if (num_active_steps_.load() == 0) { + return nullptr; + } + mutex_lock ml(metadata_mu_); + return step_container_.get(); +} + Status EagerContext::MaybeRegisterFunctionRemotely(const FunctionDef& fdef) { if (remote_device_manager_ == nullptr) return Status::OK(); #ifndef __ANDROID__ @@ -303,7 +341,9 @@ void EagerContext::InitializeRemote( std::unique_ptr<eager::EagerClientCache> remote_eager_workers, std::unique_ptr<DeviceMgr> remote_device_manager, const gtl::FlatMap<string, uint64>& remote_contexts, Rendezvous* r, - DeviceMgr* local_device_mgr) { + DeviceMgr* local_device_mgr, int keep_alive_secs) { + mutex_lock l(remote_state_mu_); + if (!remote_contexts_.empty()) { CloseRemoteContexts(); } @@ -345,6 +385,54 @@ void EagerContext::InitializeRemote( InitDeviceMapAndAsync(); ClearCaches(); + + keep_alive_secs_ = keep_alive_secs; + + sleep_for_secs_ = std::max(1, keep_alive_secs_ / 2); + + // Only schedule a single closure. + if (keep_alive_thread_ == nullptr) { + keep_alive_thread_.reset( + env_->StartThread({}, "EagerKeepAliveThread", [this]() { + while (true) { + { + { + mutex_lock l(keep_alive_thread_shutdown_mu_); + keep_alive_thread_cv_.wait_for( + l, std::chrono::seconds(sleep_for_secs_)); + + if (shutting_down_) { + return; + } + } + { + mutex_lock l(remote_state_mu_); + if (keep_alive_secs_ > 0) { + { + for (const auto& worker_and_context_id : remote_contexts_) { + auto* client = remote_eager_workers_->GetClient( + worker_and_context_id.first); + + eager::KeepAliveRequest* request = + new eager::KeepAliveRequest; + eager::KeepAliveResponse* response = + new eager::KeepAliveResponse; + + request->set_context_id(worker_and_context_id.second); + client->KeepAliveAsync( + request, response, + [request, response](const Status& s) { + delete request; + delete response; + }); + } + } + } + } + } + } + })); + } } #endif |