diff options
| field | value |
|---|---|
| author | 2018-06-19 10:25:10 -0700 |
| committer | 2018-06-19 10:28:19 -0700 |
| commit | c740b345e8c17cde0dd4691c7e240a065cb8c88c (patch) |
| tree | dd85bcff39031ec09de4507a335b541fb183adb4 /tensorflow/c/eager |
| parent | ccaf2ca02739792a8a8e50a95246f2db1197aa97 (diff) |
Allow setting server def on the eager context, and add the eager service to the grpc_tensorflow_server.
PiperOrigin-RevId: 201198350
Diffstat (limited to 'tensorflow/c/eager')
-rw-r--r-- | tensorflow/c/eager/BUILD | 5 |
-rw-r--r-- | tensorflow/c/eager/c_api.cc | 48 |
-rw-r--r-- | tensorflow/c/eager/c_api_internal.h | 4 |
-rw-r--r-- | tensorflow/c/eager/c_api_test.cc | 18 |
4 files changed, 47 insertions, 28 deletions
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index f265da2c2c..93d07135e1 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -54,7 +54,6 @@ tf_cuda_library( "//tensorflow/core/distributed_runtime/eager:eager_client", "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client", "//tensorflow/core/distributed_runtime/rpc:grpc_channel", - "//tensorflow/core/distributed_runtime/rpc/eager:eager_grpc_server_lib", "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", "//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache", "//tensorflow/core/distributed_runtime/rpc:grpc_worker_service", @@ -93,10 +92,10 @@ tf_cuda_library( "//tensorflow/core/distributed_runtime/eager:eager_client", "//tensorflow/core/distributed_runtime/eager:remote_tensor_handle", "//tensorflow/core/distributed_runtime/rpc:grpc_channel", + "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", "//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache", "//tensorflow/core/distributed_runtime/rpc:grpc_worker_service", "//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr", - "//tensorflow/core/distributed_runtime/rpc/eager:eager_grpc_server_lib", "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client", ], ) @@ -139,7 +138,7 @@ tf_cuda_cc_test( "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", - "//tensorflow/core/distributed_runtime/rpc/eager:eager_grpc_server_lib", + "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", ], ) diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 81221c4078..55d9c26b0d 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -36,9 +36,9 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/eager/execute.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" -#include "tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h" #include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" #include "tensorflow/core/distributed_runtime/server_lib.h" #include "tensorflow/core/distributed_runtime/worker_env.h" #include "tensorflow/core/framework/node_def_util.h" @@ -147,46 +147,66 @@ tensorflow::Status CreateRemoteContexts( tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts, TFE_Context** ctx) { + // We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the + // server object (which currently CHECK-fails) and we miss the error, instead, + // we log the error, and then return to allow the user to see the error + // message. +#define LOG_AND_RETURN_IF_ERROR(...) 
\ + do { \ + const ::tensorflow::Status _status = (__VA_ARGS__); \ + LOG(ERROR) << _status.error_message(); \ + if (TF_PREDICT_FALSE(!_status.ok())) return _status; \ + } while (0) + string worker_name = tensorflow::strings::StrCat( "/job:", opts->server_def.job_name(), "/replica:0/task:", opts->server_def.task_index()); - std::unique_ptr<tensorflow::eager::EagerGrpcServer> server; - TF_RETURN_IF_ERROR( - tensorflow::eager::EagerGrpcServer::Create(opts->server_def, &server)); - TF_RETURN_IF_ERROR(server->Start()); + std::unique_ptr<tensorflow::ServerInterface> server; + LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(opts->server_def, &server)); + + tensorflow::GrpcServer* grpc_server = + dynamic_cast<tensorflow::GrpcServer*>(server.get()); + if (grpc_server == nullptr) { + LOG_AND_RETURN_IF_ERROR(tensorflow::errors::Internal( + "Currently, TFE_NewContext only supports tensorflow::GrpcServer.")); + } + + LOG_AND_RETURN_IF_ERROR(grpc_server->Start()); std::vector<string> remote_workers; - server->master_env()->worker_cache->ListWorkers(&remote_workers); + grpc_server->master_env()->worker_cache->ListWorkers(&remote_workers); remote_workers.erase( std::remove(remote_workers.begin(), remote_workers.end(), worker_name), remote_workers.end()); std::unique_ptr<tensorflow::DeviceMgr> remote_device_mgr; - TF_RETURN_IF_ERROR(GetAllRemoteDevices( - remote_workers, server->master_env()->worker_cache, &remote_device_mgr)); + LOG_AND_RETURN_IF_ERROR(GetAllRemoteDevices( + remote_workers, grpc_server->master_env()->worker_cache, + &remote_device_mgr)); std::shared_ptr<tensorflow::GrpcChannelCache> channel_cache = - server->channel_cache(); + grpc_server->channel_cache(); std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers( tensorflow::eager::NewGrpcEagerClientCache(channel_cache)); // Initialize remote eager workers. 
tensorflow::gtl::FlatMap<string, tensorflow::uint64> remote_contexts; - TF_RETURN_IF_ERROR(CreateRemoteContexts(remote_workers, - remote_eager_workers.get(), - opts->async, &remote_contexts)); + LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(remote_workers, + remote_eager_workers.get(), + opts->async, &remote_contexts)); tensorflow::RemoteRendezvous* r = - server->worker_env()->rendezvous_mgr->Find(0); + grpc_server->worker_env()->rendezvous_mgr->Find(0); - auto* device_mgr = server->worker_env()->device_mgr; + auto* device_mgr = grpc_server->worker_env()->device_mgr; *ctx = new TFE_Context(opts->session_options.options, opts->policy, opts->async, device_mgr, r, std::move(server), std::move(remote_eager_workers), std::move(remote_device_mgr), remote_contexts); return tensorflow::Status::OK(); +#undef LOG_AND_RETURN_IF_ERROR } } // namespace diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index 04a6efc47c..4c5077023d 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -39,7 +39,7 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/distributed_runtime/eager/eager_client.h" #include "tensorflow/core/distributed_runtime/remote_device.h" -#include "tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h" #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h" @@ -78,7 +78,7 @@ struct TFE_Context { TFE_ContextDevicePlacementPolicy default_policy, bool async, tensorflow::DeviceMgr* local_device_mgr, tensorflow::Rendezvous* rendezvous, - std::unique_ptr<tensorflow::GrpcServer> server, + std::unique_ptr<tensorflow::ServerInterface> server, std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers, std::unique_ptr<tensorflow::DeviceMgr> remote_device_mgr, const tensorflow::gtl::FlatMap<tensorflow::string, tensorflow::uint64>& diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 992d1afd5f..1d71a78b75 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -17,7 +17,7 @@ limitations under the License. 
#include <string.h> #include "tensorflow/c/eager/c_api_test_util.h" -#include "tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" @@ -132,10 +132,10 @@ void TestRemoteExecute(bool async) { server_def.set_task_index(1); - std::unique_ptr<tensorflow::eager::EagerGrpcServer> worker_server; - ASSERT_TRUE( - tensorflow::eager::EagerGrpcServer::Create(server_def, &worker_server) - .ok()); + std::unique_ptr<tensorflow::GrpcServer> worker_server; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server) + .ok()); ASSERT_TRUE(worker_server->Start().ok()); TF_Status* status = TF_NewStatus(); @@ -215,10 +215,10 @@ void TestRemoteExecuteSilentCopies(bool async) { server_def.set_task_index(1); - std::unique_ptr<tensorflow::eager::EagerGrpcServer> worker_server; - ASSERT_TRUE( - tensorflow::eager::EagerGrpcServer::Create(server_def, &worker_server) - .ok()); + std::unique_ptr<tensorflow::GrpcServer> worker_server; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server) + .ok()); ASSERT_TRUE(worker_server->Start().ok()); TF_Status* status = TF_NewStatus(); |