diff options
Diffstat (limited to 'tensorflow/c/eager/c_api.cc')
-rw-r--r-- | tensorflow/c/eager/c_api.cc | 21 |
1 files changed, 13 insertions, 8 deletions
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index d7073d8e05..dfb1c9a376 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -110,7 +110,7 @@ tensorflow::Status GetAllRemoteDevices( tensorflow::Status CreateRemoteContexts( const std::vector<string>& remote_workers, int64 rendezvous_id, - const tensorflow::ServerDef& server_def, + int keep_alive_secs, const tensorflow::ServerDef& server_def, tensorflow::eager::EagerClientCache* remote_eager_workers, bool async, tensorflow::gtl::FlatMap<string, tensorflow::uint64>* remote_contexts) { for (int i = 0; i < remote_workers.size(); i++) { @@ -129,6 +129,7 @@ tensorflow::Status CreateRemoteContexts( request.mutable_server_def()->set_job_name(parsed_name.job); request.mutable_server_def()->set_task_index(parsed_name.task); request.set_async(async); + request.set_keep_alive_secs(keep_alive_secs); auto* eager_client = remote_eager_workers->GetClient(remote_worker); if (eager_client == nullptr) { return tensorflow::errors::Internal( @@ -151,7 +152,8 @@ tensorflow::Status CreateRemoteContexts( } tensorflow::Status UpdateTFE_ContextWithServerDef( - const tensorflow::ServerDef& server_def, TFE_Context* ctx) { + int keep_alive_secs, const tensorflow::ServerDef& server_def, + TFE_Context* ctx) { // We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the // server object (which currently CHECK-fails) and we miss the error, instead, // we log the error, and then return to allow the user to see the error @@ -202,8 +204,8 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( // Initialize remote eager workers. tensorflow::gtl::FlatMap<string, tensorflow::uint64> remote_contexts; LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts( - remote_workers, rendezvous_id, server_def, remote_eager_workers.get(), - ctx->context.Async(), &remote_contexts)); + remote_workers, rendezvous_id, keep_alive_secs, server_def, + remote_eager_workers.get(), ctx->context.Async(), &remote_contexts)); tensorflow::RemoteRendezvous* r = grpc_server->worker_env()->rendezvous_mgr->Find(rendezvous_id); @@ -222,9 +224,10 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( auto* device_mgr = grpc_server->worker_env()->device_mgr; - ctx->context.InitializeRemote( - std::move(server), std::move(remote_eager_workers), - std::move(remote_device_mgr), remote_contexts, r, device_mgr); + ctx->context.InitializeRemote(std::move(server), + std::move(remote_eager_workers), + std::move(remote_device_mgr), remote_contexts, + r, device_mgr, keep_alive_secs); return tensorflow::Status::OK(); #undef LOG_AND_RETURN_IF_ERROR @@ -288,6 +291,7 @@ void TFE_ContextClearCaches(TFE_Context* ctx) { ctx->context.ClearCaches(); } // Set server_def on the context, possibly updating it. TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx, + int keep_alive_secs, const void* proto, size_t proto_len, TF_Status* status) { @@ -297,7 +301,8 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx, "Invalid tensorflow.ServerDef protocol buffer"); return; } - status->status = UpdateTFE_ContextWithServerDef(server_def, ctx); + status->status = + UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def, ctx); } void TFE_ContextSetThreadLocalDevicePlacementPolicy( |