From 3325275eff98ffddb52a16db932481983a9de9a8 Mon Sep 17 00:00:00 2001 From: Akshay Modi Date: Wed, 8 Aug 2018 17:11:32 -0700 Subject: Support keep alive so we can reclaim memory in the remote case. PiperOrigin-RevId: 207971672 --- tensorflow/c/eager/c_api.cc | 21 +++++++++++++-------- tensorflow/c/eager/c_api.h | 1 + tensorflow/c/eager/c_api_test.cc | 8 ++++---- 3 files changed, 18 insertions(+), 12 deletions(-) (limited to 'tensorflow/c') diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index d7073d8e05..dfb1c9a376 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -110,7 +110,7 @@ tensorflow::Status GetAllRemoteDevices( tensorflow::Status CreateRemoteContexts( const std::vector& remote_workers, int64 rendezvous_id, - const tensorflow::ServerDef& server_def, + int keep_alive_secs, const tensorflow::ServerDef& server_def, tensorflow::eager::EagerClientCache* remote_eager_workers, bool async, tensorflow::gtl::FlatMap* remote_contexts) { for (int i = 0; i < remote_workers.size(); i++) { @@ -129,6 +129,7 @@ tensorflow::Status CreateRemoteContexts( request.mutable_server_def()->set_job_name(parsed_name.job); request.mutable_server_def()->set_task_index(parsed_name.task); request.set_async(async); + request.set_keep_alive_secs(keep_alive_secs); auto* eager_client = remote_eager_workers->GetClient(remote_worker); if (eager_client == nullptr) { return tensorflow::errors::Internal( @@ -151,7 +152,8 @@ tensorflow::Status CreateRemoteContexts( } tensorflow::Status UpdateTFE_ContextWithServerDef( - const tensorflow::ServerDef& server_def, TFE_Context* ctx) { + int keep_alive_secs, const tensorflow::ServerDef& server_def, + TFE_Context* ctx) { // We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the // server object (which currently CHECK-fails) and we miss the error, instead, // we log the error, and then return to allow the user to see the error @@ -202,8 +204,8 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( // Initialize remote eager workers. tensorflow::gtl::FlatMap remote_contexts; LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts( - remote_workers, rendezvous_id, server_def, remote_eager_workers.get(), - ctx->context.Async(), &remote_contexts)); + remote_workers, rendezvous_id, keep_alive_secs, server_def, + remote_eager_workers.get(), ctx->context.Async(), &remote_contexts)); tensorflow::RemoteRendezvous* r = grpc_server->worker_env()->rendezvous_mgr->Find(rendezvous_id); @@ -222,9 +224,10 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( auto* device_mgr = grpc_server->worker_env()->device_mgr; - ctx->context.InitializeRemote( - std::move(server), std::move(remote_eager_workers), - std::move(remote_device_mgr), remote_contexts, r, device_mgr); + ctx->context.InitializeRemote(std::move(server), + std::move(remote_eager_workers), + std::move(remote_device_mgr), remote_contexts, + r, device_mgr, keep_alive_secs); return tensorflow::Status::OK(); #undef LOG_AND_RETURN_IF_ERROR @@ -288,6 +291,7 @@ void TFE_ContextClearCaches(TFE_Context* ctx) { ctx->context.ClearCaches(); } // Set server_def on the context, possibly updating it. TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx, + int keep_alive_secs, const void* proto, size_t proto_len, TF_Status* status) { @@ -297,7 +301,8 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx, "Invalid tensorflow.ServerDef protocol buffer"); return; } - status->status = UpdateTFE_ContextWithServerDef(server_def, ctx); + status->status = + UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def, ctx); } void TFE_ContextSetThreadLocalDevicePlacementPolicy( diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index 092af45731..a0ebc6fa0a 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -124,6 +124,7 @@ TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context*, // If the following is set, all servers identified by the // ServerDef must be up when the context is created. TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx, + int keep_alive_secs, const void* proto, size_t proto_len, TF_Status* status); diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 00a0a71fca..71d5f3613c 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -151,7 +151,7 @@ void TestRemoteExecute(bool async) { EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); - TFE_ContextSetServerDef(ctx, serialized.data(), serialized.size(), status); + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(); @@ -239,7 +239,7 @@ void TestRemoteExecuteSilentCopies(bool async) { EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); - TFE_ContextSetServerDef(ctx, serialized.data(), serialized.size(), status); + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(); @@ -371,7 +371,7 @@ void TestRemoteExecuteChangeServerDef(bool async) { EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); - TFE_ContextSetServerDef(ctx, serialized.data(), serialized.size(), status); + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); const char remote_device_name[] = @@ -397,7 +397,7 @@ void TestRemoteExecuteChangeServerDef(bool async) { ASSERT_TRUE(s.ok()) << s.error_message(); ASSERT_TRUE(worker_server->Start().ok()); - TFE_ContextSetServerDef(ctx, serialized.data(), serialized.size(), status); + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); // Create a new tensor_handle. -- cgit v1.2.3