aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/eager
diff options
context:
space:
mode:
authorGravatar Akshay Modi <nareshmodi@google.com>2018-06-21 14:24:48 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-21 14:28:09 -0700
commit39a66ecbe0f195625a83f6e7ccfc4b3e987c3bf4 (patch)
treea9840f38cc73e409f344a364ebe7398f91542ae5 /tensorflow/c/eager
parent25be72010a2e87e776814d2feb054d9ce43d7884 (diff)
Allow dynamic specification of clusters for eager remote execution.
PiperOrigin-RevId: 201586130
Diffstat (limited to 'tensorflow/c/eager')
-rw-r--r--tensorflow/c/eager/BUILD1
-rw-r--r--tensorflow/c/eager/c_api.cc28
-rw-r--r--tensorflow/c/eager/c_api_test.cc36
3 files changed, 48 insertions, 17 deletions
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index 93d07135e1..37be52f57d 100644
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -121,6 +121,7 @@ tf_cuda_library(
tf_cuda_cc_test(
name = "c_api_test",
+ size = "small",
srcs = [
"c_api_debug_test.cc",
"c_api_test.cc",
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 6e4764bcbf..00b474fe86 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -51,6 +51,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
@@ -108,7 +109,8 @@ tensorflow::Status GetAllRemoteDevices(
}
tensorflow::Status CreateRemoteContexts(
- const std::vector<string>& remote_workers,
+ const std::vector<string>& remote_workers, int64 rendezvous_id,
+ const tensorflow::ServerDef& server_def,
tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
tensorflow::gtl::FlatMap<string, tensorflow::uint64>* remote_contexts) {
for (int i = 0; i < remote_workers.size(); i++) {
@@ -116,12 +118,14 @@ tensorflow::Status CreateRemoteContexts(
tensorflow::eager::CreateContextRequest request;
tensorflow::eager::CreateContextResponse response;
+ request.set_rendezvous_id(rendezvous_id);
tensorflow::DeviceNameUtils::ParsedName parsed_name;
if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker,
&parsed_name)) {
return tensorflow::errors::InvalidArgument(
"Unable to parse ", remote_worker, " as a device name");
}
+ *request.mutable_server_def() = server_def;
request.mutable_server_def()->set_job_name(parsed_name.job);
request.mutable_server_def()->set_task_index(parsed_name.task);
request.set_async(async);
@@ -175,6 +179,8 @@ tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts,
LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
+ int64 rendezvous_id = tensorflow::random::New64();
+
std::vector<string> remote_workers;
grpc_server->master_env()->worker_cache->ListWorkers(&remote_workers);
remote_workers.erase(
@@ -193,12 +199,24 @@ tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts,
// Initialize remote eager workers.
tensorflow::gtl::FlatMap<string, tensorflow::uint64> remote_contexts;
- LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(remote_workers,
- remote_eager_workers.get(),
- opts->async, &remote_contexts));
+ LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
+ remote_workers, rendezvous_id, opts->server_def,
+ remote_eager_workers.get(), opts->async, &remote_contexts));
tensorflow::RemoteRendezvous* r =
- grpc_server->worker_env()->rendezvous_mgr->Find(0);
+ grpc_server->worker_env()->rendezvous_mgr->Find(rendezvous_id);
+
+ auto session_name = tensorflow::strings::StrCat("eager_", rendezvous_id);
+ TF_RETURN_IF_ERROR(grpc_server->worker_env()->session_mgr->CreateSession(
+ session_name, opts->server_def, true));
+
+ std::shared_ptr<tensorflow::WorkerSession> worker_session;
+ TF_RETURN_IF_ERROR(
+ grpc_server->worker_env()->session_mgr->WorkerSessionForSession(
+ session_name, &worker_session));
+
+ // Initialize remote tensor communication based on worker session.
+ TF_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
auto* device_mgr = grpc_server->worker_env()->device_mgr;
*ctx = new TFE_Context(opts->session_options.options, opts->policy,
diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc
index cd035940ff..3504a8b5e7 100644
--- a/tensorflow/c/eager/c_api_test.cc
+++ b/tensorflow/c/eager/c_api_test.cc
@@ -143,7 +143,7 @@ void TestRemoteExecute(bool async) {
TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(),
status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
- TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(1));
+ TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts,
TFE_DEVICE_PLACEMENT_EXPLICIT);
TFE_Context* ctx = TFE_NewContext(opts, status);
@@ -208,25 +208,31 @@ TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); }
TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); }
void TestRemoteExecuteSilentCopies(bool async) {
- tensorflow::ServerDef server_def = GetServerDef(2);
+ tensorflow::ServerDef server_def = GetServerDef(3);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
+ std::unique_ptr<tensorflow::GrpcServer> worker_server1;
+ ASSERT_TRUE(tensorflow::GrpcServer::Create(
+ server_def, tensorflow::Env::Default(), &worker_server1)
+ .ok());
+ ASSERT_TRUE(worker_server1->Start().ok());
- std::unique_ptr<tensorflow::GrpcServer> worker_server;
+ server_def.set_task_index(2);
+ std::unique_ptr<tensorflow::GrpcServer> worker_server2;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
- server_def, tensorflow::Env::Default(), &worker_server)
+ server_def, tensorflow::Env::Default(), &worker_server2)
.ok());
- ASSERT_TRUE(worker_server->Start().ok());
+ ASSERT_TRUE(worker_server2->Start().ok());
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(),
status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
- TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(1));
+ TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
@@ -234,12 +240,16 @@ void TestRemoteExecuteSilentCopies(bool async) {
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle();
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle();
- const char remote_device_name[] =
- "/job:localhost/replica:0/task:1/device:CPU:0";
+ const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
+ const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
- // Handles are on task0, but op is on remote (task1).
- TFE_Op* matmul = MatMulOp(ctx, h0_task0, h1_task0);
- TFE_OpSetDevice(matmul, remote_device_name, status);
+ auto* h1_task2 =
+ TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ // Handles are on task0 (local), and task2, but op is on task1.
+ TFE_Op* matmul = MatMulOp(ctx, h0_task0, h1_task2);
+ TFE_OpSetDevice(matmul, task1_name, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* retvals[1];
@@ -265,6 +275,7 @@ void TestRemoteExecuteSilentCopies(bool async) {
TFE_DeleteTensorHandle(h0_task0);
TFE_DeleteTensorHandle(h1_task0);
+ TFE_DeleteTensorHandle(h1_task2);
TFE_DeleteTensorHandle(retvals[0]);
TFE_DeleteOp(matmul);
@@ -276,7 +287,8 @@ void TestRemoteExecuteSilentCopies(bool async) {
TF_DeleteStatus(status);
// TODO(nareshmodi): Figure out how to correctly shut the server down.
- worker_server.release();
+ worker_server1.release();
+ worker_server2.release();
}
TEST(CAPI, RemoteExecuteSilentCopies) { TestRemoteExecuteSilentCopies(false); }