aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/eager/context.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/common_runtime/eager/context.cc')
-rw-r--r--tensorflow/core/common_runtime/eager/context.cc90
1 files changed, 89 insertions, 1 deletions
diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc
index 6ab2d1ebf1..5bdd547c7f 100644
--- a/tensorflow/core/common_runtime/eager/context.cc
+++ b/tensorflow/core/common_runtime/eager/context.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/process_util.h"
+#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/util/env_var.h"
@@ -46,6 +47,7 @@ EagerContext::EagerContext(const SessionOptions& opts,
local_device_manager_.get(), opts.env, TF_GRAPH_DEF_VERSION,
&func_lib_def_, {}, thread_pool_.get())),
log_device_placement_(opts.config.log_device_placement()),
+ num_active_steps_(0),
async_default_(async),
env_(opts.env),
use_send_tensor_rpc_(false) {
@@ -161,6 +163,13 @@ EagerContext::~EagerContext() {
server_.release();
}
+ {
+ mutex_lock l(keep_alive_thread_shutdown_mu_);
+ shutting_down_ = true;
+ keep_alive_thread_cv_.notify_all();
+ }
+ keep_alive_thread_.reset();
+
CloseRemoteContexts();
#endif
@@ -194,6 +203,35 @@ Status EagerContext::FindDeviceByName(const string& name, Device** result) {
return Status::OK();
}
+void EagerContext::StartStep() {
+ mutex_lock ml(metadata_mu_);
+ num_active_steps_++;
+ if (step_container_ == nullptr) {
+ step_container_.reset(
+ new ScopedStepContainer(0, [this](const string& name) {
+ for (Device* device : devices_) {
+ device->resource_manager()->Cleanup(name).IgnoreError();
+ }
+ }));
+ }
+}
+
+void EagerContext::EndStep() {
+ mutex_lock ml(metadata_mu_);
+ num_active_steps_--;
+ if (num_active_steps_ == 0) {
+ step_container_.reset();
+ }
+}
+
+ScopedStepContainer* EagerContext::StepContainer() {
+ if (num_active_steps_.load() == 0) {
+ return nullptr;
+ }
+ mutex_lock ml(metadata_mu_);
+ return step_container_.get();
+}
+
Status EagerContext::MaybeRegisterFunctionRemotely(const FunctionDef& fdef) {
if (remote_device_manager_ == nullptr) return Status::OK();
#ifndef __ANDROID__
@@ -303,7 +341,9 @@ void EagerContext::InitializeRemote(
std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
std::unique_ptr<DeviceMgr> remote_device_manager,
const gtl::FlatMap<string, uint64>& remote_contexts, Rendezvous* r,
- DeviceMgr* local_device_mgr) {
+ DeviceMgr* local_device_mgr, int keep_alive_secs) {
+ mutex_lock l(remote_state_mu_);
+
if (!remote_contexts_.empty()) {
CloseRemoteContexts();
}
@@ -345,6 +385,54 @@ void EagerContext::InitializeRemote(
InitDeviceMapAndAsync();
ClearCaches();
+
+ keep_alive_secs_ = keep_alive_secs;
+
+ sleep_for_secs_ = std::max(1, keep_alive_secs_ / 2);
+
+ // Only schedule a single closure.
+ if (keep_alive_thread_ == nullptr) {
+ keep_alive_thread_.reset(
+ env_->StartThread({}, "EagerKeepAliveThread", [this]() {
+ while (true) {
+ {
+ {
+ mutex_lock l(keep_alive_thread_shutdown_mu_);
+ keep_alive_thread_cv_.wait_for(
+ l, std::chrono::seconds(sleep_for_secs_));
+
+ if (shutting_down_) {
+ return;
+ }
+ }
+ {
+ mutex_lock l(remote_state_mu_);
+ if (keep_alive_secs_ > 0) {
+ {
+ for (const auto& worker_and_context_id : remote_contexts_) {
+ auto* client = remote_eager_workers_->GetClient(
+ worker_and_context_id.first);
+
+ eager::KeepAliveRequest* request =
+ new eager::KeepAliveRequest;
+ eager::KeepAliveResponse* response =
+ new eager::KeepAliveResponse;
+
+ request->set_context_id(worker_and_context_id.second);
+ client->KeepAliveAsync(
+ request, response,
+ [request, response](const Status& s) {
+ delete request;
+ delete response;
+ });
+ }
+ }
+ }
+ }
+ }
+ }
+ }));
+ }
}
#endif