aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/eager
diff options
context:
space:
mode:
authorGravatar Akshay Modi <nareshmodi@google.com>2018-06-19 10:25:10 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-19 10:28:19 -0700
commitc740b345e8c17cde0dd4691c7e240a065cb8c88c (patch)
treedd85bcff39031ec09de4507a335b541fb183adb4 /tensorflow/c/eager
parentccaf2ca02739792a8a8e50a95246f2db1197aa97 (diff)
Allow setting server def on the eager context, and add the eager service to the grpc_tensorflow_server.
PiperOrigin-RevId: 201198350
Diffstat (limited to 'tensorflow/c/eager')
-rw-r--r--tensorflow/c/eager/BUILD5
-rw-r--r--tensorflow/c/eager/c_api.cc48
-rw-r--r--tensorflow/c/eager/c_api_internal.h4
-rw-r--r--tensorflow/c/eager/c_api_test.cc18
4 files changed, 47 insertions, 28 deletions
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index f265da2c2c..93d07135e1 100644
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -54,7 +54,6 @@ tf_cuda_library(
"//tensorflow/core/distributed_runtime/eager:eager_client",
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
"//tensorflow/core/distributed_runtime/rpc:grpc_channel",
- "//tensorflow/core/distributed_runtime/rpc/eager:eager_grpc_server_lib",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache",
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_service",
@@ -93,10 +92,10 @@ tf_cuda_library(
"//tensorflow/core/distributed_runtime/eager:eager_client",
"//tensorflow/core/distributed_runtime/eager:remote_tensor_handle",
"//tensorflow/core/distributed_runtime/rpc:grpc_channel",
+ "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache",
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_service",
"//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr",
- "//tensorflow/core/distributed_runtime/rpc/eager:eager_grpc_server_lib",
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
],
)
@@ -139,7 +138,7 @@ tf_cuda_cc_test(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
- "//tensorflow/core/distributed_runtime/rpc/eager:eager_grpc_server_lib",
+ "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
],
)
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 81221c4078..55d9c26b0d 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -36,9 +36,9 @@ limitations under the License.
#include "tensorflow/core/common_runtime/eager/execute.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
-#include "tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h"
#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/node_def_util.h"
@@ -147,46 +147,66 @@ tensorflow::Status CreateRemoteContexts(
tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts,
TFE_Context** ctx) {
+ // We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the
+ // server object (which currently CHECK-fails) and we miss the error, instead,
+ // we log the error, and then return to allow the user to see the error
+ // message.
+#define LOG_AND_RETURN_IF_ERROR(...) \
+ do { \
+ const ::tensorflow::Status _status = (__VA_ARGS__); \
+ LOG(ERROR) << _status.error_message(); \
+ if (TF_PREDICT_FALSE(!_status.ok())) return _status; \
+ } while (0)
+
string worker_name = tensorflow::strings::StrCat(
"/job:", opts->server_def.job_name(),
"/replica:0/task:", opts->server_def.task_index());
- std::unique_ptr<tensorflow::eager::EagerGrpcServer> server;
- TF_RETURN_IF_ERROR(
- tensorflow::eager::EagerGrpcServer::Create(opts->server_def, &server));
- TF_RETURN_IF_ERROR(server->Start());
+ std::unique_ptr<tensorflow::ServerInterface> server;
+ LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(opts->server_def, &server));
+
+ tensorflow::GrpcServer* grpc_server =
+ dynamic_cast<tensorflow::GrpcServer*>(server.get());
+ if (grpc_server == nullptr) {
+ LOG_AND_RETURN_IF_ERROR(tensorflow::errors::Internal(
+ "Currently, TFE_NewContext only supports tensorflow::GrpcServer."));
+ }
+
+ LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
std::vector<string> remote_workers;
- server->master_env()->worker_cache->ListWorkers(&remote_workers);
+ grpc_server->master_env()->worker_cache->ListWorkers(&remote_workers);
remote_workers.erase(
std::remove(remote_workers.begin(), remote_workers.end(), worker_name),
remote_workers.end());
std::unique_ptr<tensorflow::DeviceMgr> remote_device_mgr;
- TF_RETURN_IF_ERROR(GetAllRemoteDevices(
- remote_workers, server->master_env()->worker_cache, &remote_device_mgr));
+ LOG_AND_RETURN_IF_ERROR(GetAllRemoteDevices(
+ remote_workers, grpc_server->master_env()->worker_cache,
+ &remote_device_mgr));
std::shared_ptr<tensorflow::GrpcChannelCache> channel_cache =
- server->channel_cache();
+ grpc_server->channel_cache();
std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers(
tensorflow::eager::NewGrpcEagerClientCache(channel_cache));
// Initialize remote eager workers.
tensorflow::gtl::FlatMap<string, tensorflow::uint64> remote_contexts;
- TF_RETURN_IF_ERROR(CreateRemoteContexts(remote_workers,
- remote_eager_workers.get(),
- opts->async, &remote_contexts));
+ LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(remote_workers,
+ remote_eager_workers.get(),
+ opts->async, &remote_contexts));
tensorflow::RemoteRendezvous* r =
- server->worker_env()->rendezvous_mgr->Find(0);
+ grpc_server->worker_env()->rendezvous_mgr->Find(0);
- auto* device_mgr = server->worker_env()->device_mgr;
+ auto* device_mgr = grpc_server->worker_env()->device_mgr;
*ctx = new TFE_Context(opts->session_options.options, opts->policy,
opts->async, device_mgr, r, std::move(server),
std::move(remote_eager_workers),
std::move(remote_device_mgr), remote_contexts);
return tensorflow::Status::OK();
+#undef LOG_AND_RETURN_IF_ERROR
}
} // namespace
diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h
index 04a6efc47c..4c5077023d 100644
--- a/tensorflow/c/eager/c_api_internal.h
+++ b/tensorflow/c/eager/c_api_internal.h
@@ -39,7 +39,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
#include "tensorflow/core/distributed_runtime/remote_device.h"
-#include "tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
@@ -78,7 +78,7 @@ struct TFE_Context {
TFE_ContextDevicePlacementPolicy default_policy, bool async,
tensorflow::DeviceMgr* local_device_mgr,
tensorflow::Rendezvous* rendezvous,
- std::unique_ptr<tensorflow::GrpcServer> server,
+ std::unique_ptr<tensorflow::ServerInterface> server,
std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers,
std::unique_ptr<tensorflow::DeviceMgr> remote_device_mgr,
const tensorflow::gtl::FlatMap<tensorflow::string, tensorflow::uint64>&
diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc
index 992d1afd5f..1d71a78b75 100644
--- a/tensorflow/c/eager/c_api_test.cc
+++ b/tensorflow/c/eager/c_api_test.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include <string.h>
#include "tensorflow/c/eager/c_api_test_util.h"
-#include "tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
@@ -132,10 +132,10 @@ void TestRemoteExecute(bool async) {
server_def.set_task_index(1);
- std::unique_ptr<tensorflow::eager::EagerGrpcServer> worker_server;
- ASSERT_TRUE(
- tensorflow::eager::EagerGrpcServer::Create(server_def, &worker_server)
- .ok());
+ std::unique_ptr<tensorflow::GrpcServer> worker_server;
+ ASSERT_TRUE(tensorflow::GrpcServer::Create(
+ server_def, tensorflow::Env::Default(), &worker_server)
+ .ok());
ASSERT_TRUE(worker_server->Start().ok());
TF_Status* status = TF_NewStatus();
@@ -215,10 +215,10 @@ void TestRemoteExecuteSilentCopies(bool async) {
server_def.set_task_index(1);
- std::unique_ptr<tensorflow::eager::EagerGrpcServer> worker_server;
- ASSERT_TRUE(
- tensorflow::eager::EagerGrpcServer::Create(server_def, &worker_server)
- .ok());
+ std::unique_ptr<tensorflow::GrpcServer> worker_server;
+ ASSERT_TRUE(tensorflow::GrpcServer::Create(
+ server_def, tensorflow::Env::Default(), &worker_server)
+ .ok());
ASSERT_TRUE(worker_server->Start().ok());
TF_Status* status = TF_NewStatus();