diff options
author | 2018-06-21 14:24:48 -0700 | |
---|---|---|
committer | 2018-06-21 14:28:09 -0700 | |
commit | 39a66ecbe0f195625a83f6e7ccfc4b3e987c3bf4 (patch) | |
tree | a9840f38cc73e409f344a364ebe7398f91542ae5 /tensorflow/c/eager | |
parent | 25be72010a2e87e776814d2feb054d9ce43d7884 (diff) |
Allow dynamic specification of clusters for eager remote execution.
PiperOrigin-RevId: 201586130
Diffstat (limited to 'tensorflow/c/eager')
-rw-r--r-- | tensorflow/c/eager/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/c/eager/c_api.cc | 28 | ||||
-rw-r--r-- | tensorflow/c/eager/c_api_test.cc | 36 |
3 files changed, 48 insertions, 17 deletions
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 93d07135e1..37be52f57d 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -121,6 +121,7 @@ tf_cuda_library( tf_cuda_cc_test( name = "c_api_test", + size = "small", srcs = [ "c_api_debug_test.cc", "c_api_test.cc", diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 6e4764bcbf..00b474fe86 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -51,6 +51,7 @@ limitations under the License. #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/thread_annotations.h" @@ -108,7 +109,8 @@ tensorflow::Status GetAllRemoteDevices( } tensorflow::Status CreateRemoteContexts( - const std::vector<string>& remote_workers, + const std::vector<string>& remote_workers, int64 rendezvous_id, + const tensorflow::ServerDef& server_def, tensorflow::eager::EagerClientCache* remote_eager_workers, bool async, tensorflow::gtl::FlatMap<string, tensorflow::uint64>* remote_contexts) { for (int i = 0; i < remote_workers.size(); i++) { @@ -116,12 +118,14 @@ tensorflow::Status CreateRemoteContexts( tensorflow::eager::CreateContextRequest request; tensorflow::eager::CreateContextResponse response; + request.set_rendezvous_id(rendezvous_id); tensorflow::DeviceNameUtils::ParsedName parsed_name; if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker, &parsed_name)) { return tensorflow::errors::InvalidArgument( "Unable to parse ", remote_worker, " as a device name"); } + *request.mutable_server_def() = server_def; request.mutable_server_def()->set_job_name(parsed_name.job); request.mutable_server_def()->set_task_index(parsed_name.task); request.set_async(async); @@ -175,6 +179,8 @@ tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts, LOG_AND_RETURN_IF_ERROR(grpc_server->Start()); + int64 rendezvous_id = tensorflow::random::New64(); + std::vector<string> remote_workers; grpc_server->master_env()->worker_cache->ListWorkers(&remote_workers); remote_workers.erase( @@ -193,12 +199,24 @@ tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts, // Initialize remote eager workers. tensorflow::gtl::FlatMap<string, tensorflow::uint64> remote_contexts; - LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(remote_workers, - remote_eager_workers.get(), - opts->async, &remote_contexts)); + LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts( + remote_workers, rendezvous_id, opts->server_def, + remote_eager_workers.get(), opts->async, &remote_contexts)); tensorflow::RemoteRendezvous* r = - grpc_server->worker_env()->rendezvous_mgr->Find(0); + grpc_server->worker_env()->rendezvous_mgr->Find(rendezvous_id); + + auto session_name = tensorflow::strings::StrCat("eager_", rendezvous_id); + TF_RETURN_IF_ERROR(grpc_server->worker_env()->session_mgr->CreateSession( + session_name, opts->server_def, true)); + + std::shared_ptr<tensorflow::WorkerSession> worker_session; + TF_RETURN_IF_ERROR( + grpc_server->worker_env()->session_mgr->WorkerSessionForSession( + session_name, &worker_session)); + + // Initialize remote tensor communication based on worker session. + TF_RETURN_IF_ERROR(r->Initialize(worker_session.get())); auto* device_mgr = grpc_server->worker_env()->device_mgr; *ctx = new TFE_Context(opts->session_options.options, opts->policy, diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index cd035940ff..3504a8b5e7 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -143,7 +143,7 @@ void TestRemoteExecute(bool async) { TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(), status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(1)); + TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async)); TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_EXPLICIT); TFE_Context* ctx = TFE_NewContext(opts, status); @@ -208,25 +208,31 @@ TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); } TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); } void TestRemoteExecuteSilentCopies(bool async) { - tensorflow::ServerDef server_def = GetServerDef(2); + tensorflow::ServerDef server_def = GetServerDef(3); // This server def has the task index set to 0. string serialized = server_def.SerializeAsString(); server_def.set_task_index(1); + std::unique_ptr<tensorflow::GrpcServer> worker_server1; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server1) + .ok()); + ASSERT_TRUE(worker_server1->Start().ok()); - std::unique_ptr<tensorflow::GrpcServer> worker_server; + server_def.set_task_index(2); + std::unique_ptr<tensorflow::GrpcServer> worker_server2; ASSERT_TRUE(tensorflow::GrpcServer::Create( - server_def, tensorflow::Env::Default(), &worker_server) + server_def, tensorflow::Env::Default(), &worker_server2) .ok()); - ASSERT_TRUE(worker_server->Start().ok()); + ASSERT_TRUE(worker_server2->Start().ok()); TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(), status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(1)); + TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async)); TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT); TFE_Context* ctx = TFE_NewContext(opts, status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); @@ -234,12 +240,16 @@ void TestRemoteExecuteSilentCopies(bool async) { TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(); TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(); - const char remote_device_name[] = - "/job:localhost/replica:0/task:1/device:CPU:0"; + const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0"; + const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0"; - // Handles are on task0, but op is on remote (task1). - TFE_Op* matmul = MatMulOp(ctx, h0_task0, h1_task0); - TFE_OpSetDevice(matmul, remote_device_name, status); + auto* h1_task2 = + TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + // Handles are on task0 (local), and task2, but op is on task1. + TFE_Op* matmul = MatMulOp(ctx, h0_task0, h1_task2); + TFE_OpSetDevice(matmul, task1_name, status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_TensorHandle* retvals[1]; @@ -265,6 +275,7 @@ void TestRemoteExecuteSilentCopies(bool async) { TFE_DeleteTensorHandle(h0_task0); TFE_DeleteTensorHandle(h1_task0); + TFE_DeleteTensorHandle(h1_task2); TFE_DeleteTensorHandle(retvals[0]); TFE_DeleteOp(matmul); @@ -276,7 +287,8 @@ void TestRemoteExecuteSilentCopies(bool async) { TF_DeleteStatus(status); // TODO(nareshmodi): Figure out how to correctly shut the server down. - worker_server.release(); + worker_server1.release(); + worker_server2.release(); } TEST(CAPI, RemoteExecuteSilentCopies) { TestRemoteExecuteSilentCopies(false); } |