aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/c/eager/c_api.cc21
-rw-r--r--tensorflow/c/eager/c_api.h1
-rw-r--r--tensorflow/c/eager/c_api_test.cc8
-rw-r--r--tensorflow/core/common_runtime/eager/context.cc59
-rw-r--r--tensorflow/core/common_runtime/eager/context.h14
-rw-r--r--tensorflow/core/distributed_runtime/eager/eager_service_impl.cc17
-rw-r--r--tensorflow/core/distributed_runtime/eager/eager_service_impl.h68
-rw-r--r--tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc41
-rw-r--r--tensorflow/python/eager/context.py11
9 files changed, 214 insertions, 26 deletions
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index d7073d8e05..dfb1c9a376 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -110,7 +110,7 @@ tensorflow::Status GetAllRemoteDevices(
tensorflow::Status CreateRemoteContexts(
const std::vector<string>& remote_workers, int64 rendezvous_id,
- const tensorflow::ServerDef& server_def,
+ int keep_alive_secs, const tensorflow::ServerDef& server_def,
tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
tensorflow::gtl::FlatMap<string, tensorflow::uint64>* remote_contexts) {
for (int i = 0; i < remote_workers.size(); i++) {
@@ -129,6 +129,7 @@ tensorflow::Status CreateRemoteContexts(
request.mutable_server_def()->set_job_name(parsed_name.job);
request.mutable_server_def()->set_task_index(parsed_name.task);
request.set_async(async);
+ request.set_keep_alive_secs(keep_alive_secs);
auto* eager_client = remote_eager_workers->GetClient(remote_worker);
if (eager_client == nullptr) {
return tensorflow::errors::Internal(
@@ -151,7 +152,8 @@ tensorflow::Status CreateRemoteContexts(
}
tensorflow::Status UpdateTFE_ContextWithServerDef(
- const tensorflow::ServerDef& server_def, TFE_Context* ctx) {
+ int keep_alive_secs, const tensorflow::ServerDef& server_def,
+ TFE_Context* ctx) {
// We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the
// server object (which currently CHECK-fails) and we miss the error, instead,
// we log the error, and then return to allow the user to see the error
@@ -202,8 +204,8 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
// Initialize remote eager workers.
tensorflow::gtl::FlatMap<string, tensorflow::uint64> remote_contexts;
LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
- remote_workers, rendezvous_id, server_def, remote_eager_workers.get(),
- ctx->context.Async(), &remote_contexts));
+ remote_workers, rendezvous_id, keep_alive_secs, server_def,
+ remote_eager_workers.get(), ctx->context.Async(), &remote_contexts));
tensorflow::RemoteRendezvous* r =
grpc_server->worker_env()->rendezvous_mgr->Find(rendezvous_id);
@@ -222,9 +224,10 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
auto* device_mgr = grpc_server->worker_env()->device_mgr;
- ctx->context.InitializeRemote(
- std::move(server), std::move(remote_eager_workers),
- std::move(remote_device_mgr), remote_contexts, r, device_mgr);
+ ctx->context.InitializeRemote(std::move(server),
+ std::move(remote_eager_workers),
+ std::move(remote_device_mgr), remote_contexts,
+ r, device_mgr, keep_alive_secs);
return tensorflow::Status::OK();
#undef LOG_AND_RETURN_IF_ERROR
@@ -288,6 +291,7 @@ void TFE_ContextClearCaches(TFE_Context* ctx) { ctx->context.ClearCaches(); }
// Set server_def on the context, possibly updating it.
TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
+ int keep_alive_secs,
const void* proto,
size_t proto_len,
TF_Status* status) {
@@ -297,7 +301,8 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
"Invalid tensorflow.ServerDef protocol buffer");
return;
}
- status->status = UpdateTFE_ContextWithServerDef(server_def, ctx);
+ status->status =
+ UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def, ctx);
}
void TFE_ContextSetThreadLocalDevicePlacementPolicy(
diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h
index 092af45731..a0ebc6fa0a 100644
--- a/tensorflow/c/eager/c_api.h
+++ b/tensorflow/c/eager/c_api.h
@@ -124,6 +124,7 @@ TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context*,
// If the following is set, all servers identified by the
// ServerDef must be up when the context is created.
TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
+ int keep_alive_secs,
const void* proto,
size_t proto_len,
TF_Status* status);
diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc
index 00a0a71fca..71d5f3613c 100644
--- a/tensorflow/c/eager/c_api_test.cc
+++ b/tensorflow/c/eager/c_api_test.cc
@@ -151,7 +151,7 @@ void TestRemoteExecute(bool async) {
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
- TFE_ContextSetServerDef(ctx, serialized.data(), serialized.size(), status);
+ TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle();
@@ -239,7 +239,7 @@ void TestRemoteExecuteSilentCopies(bool async) {
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
- TFE_ContextSetServerDef(ctx, serialized.data(), serialized.size(), status);
+ TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle();
@@ -371,7 +371,7 @@ void TestRemoteExecuteChangeServerDef(bool async) {
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
- TFE_ContextSetServerDef(ctx, serialized.data(), serialized.size(), status);
+ TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
const char remote_device_name[] =
@@ -397,7 +397,7 @@ void TestRemoteExecuteChangeServerDef(bool async) {
ASSERT_TRUE(s.ok()) << s.error_message();
ASSERT_TRUE(worker_server->Start().ok());
- TFE_ContextSetServerDef(ctx, serialized.data(), serialized.size(), status);
+ TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// Create a new tensor_handle.
diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc
index e5fe87fc37..5bdd547c7f 100644
--- a/tensorflow/core/common_runtime/eager/context.cc
+++ b/tensorflow/core/common_runtime/eager/context.cc
@@ -163,6 +163,13 @@ EagerContext::~EagerContext() {
server_.release();
}
+ {
+ mutex_lock l(keep_alive_thread_shutdown_mu_);
+ shutting_down_ = true;
+ keep_alive_thread_cv_.notify_all();
+ }
+ keep_alive_thread_.reset();
+
CloseRemoteContexts();
#endif
@@ -334,7 +341,9 @@ void EagerContext::InitializeRemote(
std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
std::unique_ptr<DeviceMgr> remote_device_manager,
const gtl::FlatMap<string, uint64>& remote_contexts, Rendezvous* r,
- DeviceMgr* local_device_mgr) {
+ DeviceMgr* local_device_mgr, int keep_alive_secs) {
+ mutex_lock l(remote_state_mu_);
+
if (!remote_contexts_.empty()) {
CloseRemoteContexts();
}
@@ -376,6 +385,54 @@ void EagerContext::InitializeRemote(
InitDeviceMapAndAsync();
ClearCaches();
+
+ keep_alive_secs_ = keep_alive_secs;
+
+ sleep_for_secs_ = std::max(1, keep_alive_secs_ / 2);
+
+ // Only schedule a single closure.
+ if (keep_alive_thread_ == nullptr) {
+ keep_alive_thread_.reset(
+ env_->StartThread({}, "EagerKeepAliveThread", [this]() {
+ while (true) {
+ {
+ {
+ mutex_lock l(keep_alive_thread_shutdown_mu_);
+ keep_alive_thread_cv_.wait_for(
+ l, std::chrono::seconds(sleep_for_secs_));
+
+ if (shutting_down_) {
+ return;
+ }
+ }
+ {
+ mutex_lock l(remote_state_mu_);
+ if (keep_alive_secs_ > 0) {
+ {
+ for (const auto& worker_and_context_id : remote_contexts_) {
+ auto* client = remote_eager_workers_->GetClient(
+ worker_and_context_id.first);
+
+ eager::KeepAliveRequest* request =
+ new eager::KeepAliveRequest;
+ eager::KeepAliveResponse* response =
+ new eager::KeepAliveResponse;
+
+ request->set_context_id(worker_and_context_id.second);
+ client->KeepAliveAsync(
+ request, response,
+ [request, response](const Status& s) {
+ delete request;
+ delete response;
+ });
+ }
+ }
+ }
+ }
+ }
+ }
+ }));
+ }
}
#endif
diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h
index 3eea56b5e3..ebaf500bb3 100644
--- a/tensorflow/core/common_runtime/eager/context.h
+++ b/tensorflow/core/common_runtime/eager/context.h
@@ -180,7 +180,7 @@ class EagerContext {
std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
std::unique_ptr<DeviceMgr> remote_device_manager,
const gtl::FlatMap<string, uint64>& remote_contexts, Rendezvous* r,
- DeviceMgr* local_device_mgr);
+ DeviceMgr* local_device_mgr, int keep_alive_secs);
bool HasActiveRemoteContext(uint64 context_id) {
return active_remote_contexts_.find(context_id) !=
@@ -190,7 +190,7 @@ class EagerContext {
// If true, then tensors should be shipped across processes via the
// EagerService.SendTensor RPC. If false, _Send/_Recv ops should be used
- // instead (which in-turn use WorkerService.RecvTensor RPCs.
+ // instead (which in-turn use WorkerService.RecvTensor RPCs).
bool UseSendTensorRPC() { return use_send_tensor_rpc_; }
private:
@@ -263,10 +263,20 @@ class EagerContext {
std::unique_ptr<ServerInterface> server_;
std::unique_ptr<eager::EagerClientCache> remote_eager_workers_;
+ mutex remote_state_mu_;
+
gtl::FlatMap<string, uint64> remote_contexts_;
gtl::FlatSet<uint64> active_remote_contexts_;
gtl::FlatMap<Device*, std::pair<eager::EagerClient*, uint64>>
device_to_client_cache_;
+
+ int keep_alive_secs_ GUARDED_BY(remote_state_mu_);
+ std::atomic<int> sleep_for_secs_;
+
+ std::unique_ptr<Thread> keep_alive_thread_;
+ mutex keep_alive_thread_shutdown_mu_;
+ condition_variable keep_alive_thread_cv_;
+ bool shutting_down_ GUARDED_BY(keep_alive_thread_shutdown_mu_) = false;
#endif
bool use_send_tensor_rpc_;
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
index 916c8720f0..b8af63724a 100644
--- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
+++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
@@ -126,7 +126,9 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
do {
context_id = random::New64();
} while (contexts_.find(context_id) != contexts_.end());
- contexts_.emplace(context_id, new ServerContext(std::move(ctx)));
+ contexts_.emplace(
+ context_id,
+ new ServerContext(std::move(ctx), request->keep_alive_secs(), env_));
}
response->set_context_id(context_id);
@@ -231,9 +233,11 @@ Status EagerServiceImpl::WaitQueueDone(const WaitQueueDoneRequest* request,
Status EagerServiceImpl::KeepAlive(const KeepAliveRequest* request,
KeepAliveResponse* response) {
- // TODO(nareshmodi): Automated context_id cleaning is not implemented
- return errors::Unimplemented(
- "EagerServiceImpl::KeepAlive is not implemented.");
+ ServerContext* context = nullptr;
+ TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context));
+ core::ScopedUnref context_unref(context);
+
+ return Status::OK();
}
Status EagerServiceImpl::CloseContext(const CloseContextRequest* request,
@@ -304,12 +308,15 @@ tensorflow::Status EagerServiceImpl::GetServerContext(
*server_context = nullptr;
return errors::InvalidArgument(strings::Printf(
"Unable to find a context_id matching the specified one "
- "(%lld). Perhaps the worker was restarted?",
+ "(%lld). Perhaps the worker was restarted, or the context was GC'd?",
context_id));
}
*server_context = iter->second;
(*server_context)->Ref();
+
+ (*server_context)->RecordAccess();
+
return Status::OK();
}
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.h b/tensorflow/core/distributed_runtime/eager/eager_service_impl.h
index 718b4e2457..5723106aa6 100644
--- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.h
+++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.h
@@ -38,8 +38,41 @@ namespace eager {
// over this (e.g. gRPC).
class EagerServiceImpl {
public:
- explicit EagerServiceImpl(const WorkerEnv* env) : env_(env) {}
+ explicit EagerServiceImpl(const WorkerEnv* env) : env_(env) {
+ gc_thread_.reset(
+ env_->env->StartThread({}, "EagerServiceContextGC", [this]() {
+ while (true) {
+ {
+ mutex_lock l(gc_thread_shutdown_mu_);
+ gc_thread_cv_.wait_for(l, std::chrono::seconds(1));
+
+ if (shutting_down_) {
+ return;
+ }
+ }
+ {
+ mutex_lock l(contexts_mu_);
+ for (auto it = contexts_.begin(); it != contexts_.end();) {
+ if (it->second->IsStale()) {
+ it->second->Unref();
+ it = contexts_.erase(it);
+ } else {
+ it++;
+ }
+ }
+ }
+ }
+ }));
+ }
virtual ~EagerServiceImpl() {
+ {
+ mutex_lock l(gc_thread_shutdown_mu_);
+ shutting_down_ = true;
+ gc_thread_cv_.notify_all();
+ }
+ gc_thread_.reset();
+
+ mutex_lock l(contexts_mu_);
for (auto& entry : contexts_) {
entry.second->Unref();
}
@@ -71,8 +104,13 @@ class EagerServiceImpl {
// and the EagerContext).
class ServerContext : public core::RefCounted {
public:
- explicit ServerContext(std::unique_ptr<tensorflow::EagerContext> ctx)
- : ctx_(std::move(ctx)) {}
+ explicit ServerContext(std::unique_ptr<tensorflow::EagerContext> ctx,
+ int64 destroy_after_secs, const WorkerEnv* env)
+ : ctx_(std::move(ctx)), env_(env) {
+ destroy_after_micros_ =
+ destroy_after_secs * tensorflow::EnvTime::kSecondsToMicros;
+ RecordAccess();
+ }
~ServerContext() {
for (const auto& entry : tensors_) {
entry.second->Unref();
@@ -122,6 +160,18 @@ class EagerServiceImpl {
return Status::OK();
}
+ void RecordAccess() {
+ mutex_lock l(last_accessed_mu_);
+ last_accessed_micros_ = env_->env->NowMicros();
+ }
+
+ bool IsStale() {
+ mutex_lock l(last_accessed_mu_);
+ return (destroy_after_micros_ <= 0 ||
+ (env_->env->NowMicros() - last_accessed_micros_) >
+ destroy_after_micros_);
+ }
+
private:
using RemoteTensorHandleMap =
gtl::FlatMap<RemoteTensorHandleInternal, tensorflow::TensorHandle*,
@@ -131,8 +181,15 @@ class EagerServiceImpl {
// The context for this execution.
std::unique_ptr<tensorflow::EagerContext> ctx_;
+ // The state related to the context for this execution.
mutex tensors_mu_;
RemoteTensorHandleMap tensors_ GUARDED_BY(tensors_mu_);
+
+ const WorkerEnv* const env_; // Not owned.
+
+ mutex last_accessed_mu_;
+ int64 last_accessed_micros_ GUARDED_BY(last_accessed_mu_);
+ int64 destroy_after_micros_;
};
// The returned ServerContext will need to be Unrefed.
tensorflow::Status GetServerContext(uint64, ServerContext**);
@@ -145,6 +202,11 @@ class EagerServiceImpl {
mutex contexts_mu_;
std::unordered_map<uint64, ServerContext*> contexts_ GUARDED_BY(contexts_mu_);
+ std::unique_ptr<Thread> gc_thread_;
+ mutex gc_thread_shutdown_mu_;
+ condition_variable gc_thread_cv_;
+ bool shutting_down_ GUARDED_BY(gc_thread_shutdown_mu_) = false;
+
TF_DISALLOW_COPY_AND_ASSIGN(EagerServiceImpl);
};
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc
index d1f2a6da8f..5c9b33b345 100644
--- a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc
+++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc
@@ -365,6 +365,47 @@ TEST_F(EagerServiceImplTest, SendTensorTest) {
&close_context_response));
}
+TEST_F(EagerServiceImplTest, KeepAliveTest) {
+ TestEagerServiceImpl eager_service_impl(&worker_env_);
+
+ CreateContextRequest request;
+ request.mutable_server_def()->set_job_name("localhost");
+ request.mutable_server_def()->set_task_index(0);
+ request.set_rendezvous_id(random::New64());
+ request.set_keep_alive_secs(3);
+ CreateContextResponse response;
+
+ TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));
+
+ worker_env_.env->SleepForMicroseconds(5 *
+ tensorflow::EnvTime::kSecondsToMicros);
+
+ KeepAliveRequest keep_alive_request;
+ KeepAliveResponse keep_alive_response;
+
+ keep_alive_request.set_context_id(response.context_id());
+
+ Status status =
+ eager_service_impl.KeepAlive(&keep_alive_request, &keep_alive_response);
+
+ EXPECT_EQ(status.code(), error::INVALID_ARGUMENT);
+ EXPECT_PRED_FORMAT2(::testing::IsSubstring, "Unable to find a context_id",
+ status.error_message());
+
+ // Create a new context.
+ request.set_rendezvous_id(random::New64());
+ TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));
+
+ // The context should not be GC'd.
+ worker_env_.env->SleepForMicroseconds(1 *
+ tensorflow::EnvTime::kSecondsToMicros);
+
+ keep_alive_request.set_context_id(response.context_id());
+
+ TF_ASSERT_OK(
+ eager_service_impl.KeepAlive(&keep_alive_request, &keep_alive_response));
+}
+
} // namespace
} // namespace eager
} // namespace tensorflow
diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py
index 09223c86d4..aa57ca03e6 100644
--- a/tensorflow/python/eager/context.py
+++ b/tensorflow/python/eager/context.py
@@ -265,7 +265,7 @@ class Context(object):
pywrap_tensorflow.TFE_DeleteContextOptions(opts)
if self._server_def is not None:
server_def_str = self._server_def.SerializeToString()
- pywrap_tensorflow.TFE_ContextSetServerDef(self._context_handle,
+ pywrap_tensorflow.TFE_ContextSetServerDef(self._context_handle, 600,
server_def_str)
self._initialize_devices()
@@ -275,7 +275,7 @@ class Context(object):
self.ones_rank_cache().flush()
self.zeros_cache().flush()
- def set_server_def(self, server_def):
+ def set_server_def(self, server_def, keep_alive_secs=600):
"""Allow setting a server_def on the context.
When a server def is replaced, it effectively clears a bunch of caches
@@ -285,6 +285,11 @@ class Context(object):
Args:
server_def: A tensorflow::ServerDef proto.
Enables execution on remote devices.
+ keep_alive_secs: Num. seconds after which the remote end will hang up.
+ As long as the client is still alive, the server state for the context
+ will be kept alive. If the client is killed (or there is some failure),
+ the server will clean up its context keep_alive_secs after the final RPC
+ it receives.
Raises:
ValueError: if server_def is None.
@@ -296,7 +301,7 @@ class Context(object):
else:
server_def_str = server_def.SerializeToString()
pywrap_tensorflow.TFE_ContextSetServerDef(self._context_handle,
- server_def_str)
+ keep_alive_secs, server_def_str)
# Clear all the caches in case there are remote tensors in them.
self._clear_caches()